xref: /aosp_15_r20/external/pytorch/aten/src/ATen/PTThreadPool.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/Parallel.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/core/thread_pool.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker namespace at {
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker class TORCH_API PTThreadPool : public c10::ThreadPool {
9*da0073e9SAndroid Build Coastguard Worker  public:
10*da0073e9SAndroid Build Coastguard Worker   explicit PTThreadPool(int pool_size, int numa_node_id = -1)
11*da0073e9SAndroid Build Coastguard Worker       : c10::ThreadPool(pool_size, numa_node_id, []() {
12*da0073e9SAndroid Build Coastguard Worker           c10::setThreadName("PTThreadPool");
13*da0073e9SAndroid Build Coastguard Worker           at::init_num_threads();
14*da0073e9SAndroid Build Coastguard Worker         }) {}
15*da0073e9SAndroid Build Coastguard Worker };
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker } // namespace at
18