xref: /aosp_15_r20/external/pytorch/caffe2/utils/threadpool/pthreadpool-cpp.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #ifdef USE_PTHREADPOOL
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #ifdef USE_INTERNAL_PTHREADPOOL_IMPL
6*da0073e9SAndroid Build Coastguard Worker #include <caffe2/utils/threadpool/pthreadpool.h>
7*da0073e9SAndroid Build Coastguard Worker #else
8*da0073e9SAndroid Build Coastguard Worker #include <pthreadpool.h>
9*da0073e9SAndroid Build Coastguard Worker #endif
10*da0073e9SAndroid Build Coastguard Worker 
11*da0073e9SAndroid Build Coastguard Worker #include <functional>
12*da0073e9SAndroid Build Coastguard Worker #include <memory>
13*da0073e9SAndroid Build Coastguard Worker #include <mutex>
14*da0073e9SAndroid Build Coastguard Worker 
15*da0073e9SAndroid Build Coastguard Worker namespace caffe2 {
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker class PThreadPool final {
18*da0073e9SAndroid Build Coastguard Worker  public:
19*da0073e9SAndroid Build Coastguard Worker   explicit PThreadPool(size_t thread_count);
20*da0073e9SAndroid Build Coastguard Worker   ~PThreadPool() = default;
21*da0073e9SAndroid Build Coastguard Worker 
22*da0073e9SAndroid Build Coastguard Worker   PThreadPool(const PThreadPool&) = delete;
23*da0073e9SAndroid Build Coastguard Worker   PThreadPool& operator=(const PThreadPool&) = delete;
24*da0073e9SAndroid Build Coastguard Worker 
25*da0073e9SAndroid Build Coastguard Worker   PThreadPool(PThreadPool&&) = delete;
26*da0073e9SAndroid Build Coastguard Worker   PThreadPool& operator=(PThreadPool&&) = delete;
27*da0073e9SAndroid Build Coastguard Worker 
28*da0073e9SAndroid Build Coastguard Worker   size_t get_thread_count() const;
29*da0073e9SAndroid Build Coastguard Worker   void set_thread_count(size_t thread_count);
30*da0073e9SAndroid Build Coastguard Worker 
31*da0073e9SAndroid Build Coastguard Worker   // Run, in parallel, function fn(task_id) over task_id in range [0, range).
32*da0073e9SAndroid Build Coastguard Worker   // This function is blocking.  All input is processed by the time it returns.
33*da0073e9SAndroid Build Coastguard Worker   void run(const std::function<void(size_t)>& fn, size_t range);
34*da0073e9SAndroid Build Coastguard Worker 
35*da0073e9SAndroid Build Coastguard Worker  private:
36*da0073e9SAndroid Build Coastguard Worker   friend pthreadpool_t pthreadpool_();
37*da0073e9SAndroid Build Coastguard Worker 
38*da0073e9SAndroid Build Coastguard Worker  private:
39*da0073e9SAndroid Build Coastguard Worker   mutable std::mutex mutex_;
40*da0073e9SAndroid Build Coastguard Worker   std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy)> threadpool_;
41*da0073e9SAndroid Build Coastguard Worker };
42*da0073e9SAndroid Build Coastguard Worker 
43*da0073e9SAndroid Build Coastguard Worker // Return a singleton instance of PThreadPool for ATen/TH multithreading.
44*da0073e9SAndroid Build Coastguard Worker PThreadPool* pthreadpool();
45*da0073e9SAndroid Build Coastguard Worker 
46*da0073e9SAndroid Build Coastguard Worker // Exposes the underlying implementation of PThreadPool.
47*da0073e9SAndroid Build Coastguard Worker // Only for use in external libraries so as to unify threading across
48*da0073e9SAndroid Build Coastguard Worker // internal (i.e. ATen, etc.) and external (e.g. NNPACK, QNNPACK, XNNPACK)
49*da0073e9SAndroid Build Coastguard Worker // use cases.
50*da0073e9SAndroid Build Coastguard Worker pthreadpool_t pthreadpool_();
51*da0073e9SAndroid Build Coastguard Worker 
52*da0073e9SAndroid Build Coastguard Worker } // namespace caffe2
53*da0073e9SAndroid Build Coastguard Worker 
54*da0073e9SAndroid Build Coastguard Worker #endif /* USE_PTHREADPOOL */
55