xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tfrt/runtime/tf_threadpool_concurrent_work_queue.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TFRT_RUNTIME_TF_THREADPOOL_CONCURRENT_WORK_QUEUE_H_
16 #define TENSORFLOW_CORE_TFRT_RUNTIME_TF_THREADPOOL_CONCURRENT_WORK_QUEUE_H_
17 
18 #include <string>
19 
20 #include "tensorflow/core/platform/cpu_info.h"
21 #include "tensorflow/core/platform/status.h"
22 #include "tensorflow/core/platform/threadpool_interface.h"
23 #include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
24 #include "tfrt/host_context/async_value.h"  // from @tf_runtime
25 #include "tfrt/host_context/concurrent_work_queue.h"  // from @tf_runtime
26 #include "tfrt/host_context/execution_context.h"  // from @tf_runtime
27 #include "tfrt/host_context/task_function.h"  // from @tf_runtime
28 #include "tfrt/support/forward_decls.h"  // from @tf_runtime
29 
30 namespace tensorflow {
31 namespace tfrt_stub {
32 
33 // This class defines a work queue based on the WorkQueueInterface that uses the
34 // Tensorflow threadpools to execute inter-op and intra-op closures.
35 class TfThreadPoolWorkQueue : public WorkQueueInterface {
36  public:
TfThreadPoolWorkQueue(tensorflow::thread::ThreadPoolInterface * intra_op_threadpool,tensorflow::thread::ThreadPoolInterface * inter_op_threadpool)37   TfThreadPoolWorkQueue(
38       tensorflow::thread::ThreadPoolInterface* intra_op_threadpool,
39       tensorflow::thread::ThreadPoolInterface* inter_op_threadpool)
40       : TfThreadPoolWorkQueue(/*id=*/0, intra_op_threadpool,
41                               inter_op_threadpool) {}
42 
TfThreadPoolWorkQueue(int64_t id,tensorflow::thread::ThreadPoolInterface * intra_op_threadpool,tensorflow::thread::ThreadPoolInterface * inter_op_threadpool)43   TfThreadPoolWorkQueue(
44       int64_t id, tensorflow::thread::ThreadPoolInterface* intra_op_threadpool,
45       tensorflow::thread::ThreadPoolInterface* inter_op_threadpool)
46       : WorkQueueInterface(id),
47         intra_op_threadpool_(intra_op_threadpool),
48         inter_op_threadpool_(inter_op_threadpool) {}
49 
50   StatusOr<std::unique_ptr<WorkQueueInterface>> InitializeRequest(
51       ::tfrt::RequestContextBuilder* request_context_builder,
52       tensorflow::thread::ThreadPoolInterface** intra_op_threadpool)
53       const override;
54 
GetParallelismLevel()55   int GetParallelismLevel() const override {
56     return tensorflow::port::MaxParallelism();
57   }
name()58   std::string name() const override { return "TfThreadPoolWorkQueue"; }
59 
60   void AddTask(tfrt::TaskFunction work) override;
61 
62   llvm::Optional<tfrt::TaskFunction> AddBlockingTask(
63       tfrt::TaskFunction work, bool allow_queuing) override;
64 
65   void Quiesce() override;
66 
67   void Await(
68       tfrt::ArrayRef<::tfrt::RCReference<::tfrt::AsyncValue>> values) override;
69 
70   bool IsInWorkerThread() const override;
71 
72  private:
73   tensorflow::thread::ThreadPoolInterface* intra_op_threadpool_ = nullptr;
74   tensorflow::thread::ThreadPoolInterface* inter_op_threadpool_ = nullptr;
75 };
76 
77 // Create a default TfThreadPoolWorkQueue that is implemented by
78 // tensorflow::thread::ThreadPool. `num_inter_op_threads` and
79 // `num_intra_op_threads` must be larger than zero.
80 std::unique_ptr<TfThreadPoolWorkQueue> CreateDefaultTfThreadPoolWorkQueue(
81     int num_inter_op_threads, int num_intra_op_threads);
82 
83 }  // namespace tfrt_stub
84 }  // namespace tensorflow
85 
86 #endif  // TENSORFLOW_CORE_TFRT_RUNTIME_TF_THREADPOOL_CONCURRENT_WORK_QUEUE_H_
87