1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <functional> 12 #include <memory> 13 #include <mutex> 14 15 #include <pthreadpool.h> 16 17 namespace executorch::extension::threadpool { 18 19 class ThreadPool final { 20 public: 21 explicit ThreadPool(size_t thread_count = 0); 22 ~ThreadPool() = default; 23 24 // Make threadpool non copyable 25 // Non-copyable: threadpool cannot be copied because it will 26 // effectively require cloning of threadpool. 27 // Cloning can be done by just calling create_thread_pool. 28 ThreadPool(const ThreadPool&) = delete; 29 ThreadPool& operator=(const ThreadPool&) = delete; 30 31 // Make threadpool non-movable. 32 ThreadPool(ThreadPool&&) = delete; 33 ThreadPool& operator=(ThreadPool&&) = delete; 34 35 size_t get_thread_count() const; 36 37 /** 38 * INTERNAL: Resets the threadpool by creating a new threadpool with requested 39 * # of threads. This is not a thread safe call. When calling this method, 40 * threads of the threadpool might be doing some work. Some other code may 41 * also be holding on to the threadpool pointer, that is no longer valid. This 42 * is a private API, which will later be replaced by something that allows 43 * creating of threadpool with requested size and use such a threadpool with 44 * backend delegates, custom ops or optimized lib. 45 */ 46 [[deprecated("This API is experimental and may change without notice.")]] 47 bool _unsafe_reset_threadpool(uint32_t num_threads); 48 49 /** 50 * Run, in parallel, function fn(task_id) over task_id in range [0, range). 51 * This function is blocking. All input is processed by the time it returns. 52 * NoThreadPoolGuard (see threadpool_guard.h) can used to disable use of 53 * multiple threads with the scope of the guard When NoThreadPoolGuard is not 54 * used all calls to run method are serialized. 55 */ 56 void run(const std::function<void(size_t)>& fn, size_t range); 57 58 private: 59 friend pthreadpool_t get_pthreadpool(); 60 61 private: 62 // This mutex is used inside get_thread_count API but it is not really needed 63 // since data members of ThreadPool objects are not really mutable. 64 // TODO(kimishpatel): Figure out if we will allow set_num_threads API, in 65 // which case this mutex will be useful. Otherwise remove it. 66 mutable std::mutex mutex_; 67 std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy)> threadpool_; 68 }; 69 70 /** 71 * Returns the singleton instance of ThreadPool for ATen/TH multithreading. 72 */ 73 ThreadPool* get_threadpool(); 74 75 /** 76 * Returns the underlying pthreadpool instance used by the implementation of 77 * ThreadPool returned by `get_threadpool()`. Only for use in external libraries 78 * so as to unify threading across internal (i.e. ATen, etc.) and external (e.g. 79 * NNPACK, QNNPACK, XNNPACK) use cases. 80 */ 81 pthreadpool_t get_pthreadpool(); 82 83 } // namespace executorch::extension::threadpool 84 85 namespace torch::executorch::threadpool { // DEPRECATED 86 // TODO(T197294990): Remove these deprecated aliases once all users have moved 87 // to the new `::executorch` namespaces. Note that threadpool incorrectly used 88 // the namespace `torch::executorch` instead of `torch::executor`. 89 using ::executorch::extension::threadpool::get_pthreadpool; // DEPRECATED 90 using ::executorch::extension::threadpool::get_threadpool; // DEPRECATED 91 using ::executorch::extension::threadpool::ThreadPool; // DEPRECATED 92 } // namespace torch::executorch::threadpool 93