/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/runtime/tf_threadpool_concurrent_work_queue.h"

#include <utility>

#include "llvm/ADT/None.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/tfrt/utils/thread_pool.h"
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/task_function.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
#include "tfrt/support/latch.h"  // from @tf_runtime

namespace tensorflow {
namespace tfrt_stub {

using ::tensorflow::thread::ThreadPoolInterface;

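// Creates a per-request work queue that reuses this queue's inter-op and
// intra-op thread pools, and reports the intra-op pool back to the caller
// through the out parameter.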
StatusOr<std::unique_ptr<WorkQueueInterface>>
TfThreadPoolWorkQueue::InitializeRequest(
    ::tfrt::RequestContextBuilder* request_context_builder,
    ThreadPoolInterface** intra_op_threadpool) const {
  DCHECK(intra_op_threadpool);
  *intra_op_threadpool = intra_op_threadpool_;

  return {std::make_unique<TfThreadPoolWorkQueue>(request_context_builder->id(),
                                                  intra_op_threadpool_,
                                                  inter_op_threadpool_)};
}

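// Schedules a non-blocking task on the inter-op thread pool. The wrapped task
// is heap-allocated because ThreadPoolInterface::Schedule takes a copyable
// std::function, while tfrt::TaskFunction is move-only; the copy is deleted
// once the task has run.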
void TfThreadPoolWorkQueue::AddTask(tfrt::TaskFunction work) {
  auto* copy = new tfrt::TaskFunction(
      tensorflow::tfrt_stub::WrapWork(id(), "inter", std::move(work)));
  inter_op_threadpool_->Schedule([copy] {
    (*copy)();
    delete copy;
  });
}

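// Blocking tasks are forwarded to the same inter-op thread pool. The task is
// always enqueued regardless of `allow_queuing`, so llvm::None is returned to
// signal that the caller does not need to run it itself.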
llvm::Optional<tfrt::TaskFunction> TfThreadPoolWorkQueue::AddBlockingTask(
    tfrt::TaskFunction work, bool allow_queuing) {
  AddTask(std::move(work));
  return llvm::None;
}

void TfThreadPoolWorkQueue::Quiesce() {
  // TODO(b/186668821): implement this
  CHECK(false);  // Crash OK
}

// From
// third_party/tensorflow/core/tfrt/run_handler_thread_pool/run_handler_concurrent_work_queue.cc
void TfThreadPoolWorkQueue::Await(
    tfrt::ArrayRef<tfrt::RCReference<tfrt::AsyncValue>> values) {
  // We are done when values_remaining drops to zero.
  tfrt::latch values_remaining(values.size());

  // As each value becomes available, we decrement the count.
  for (auto& value : values) {
    value->AndThen([&values_remaining]() { values_remaining.count_down(); });
  }

  // Wait until all values are resolved.
  values_remaining.wait();
}

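// Conservatively reports that the caller is on a worker thread; the TODO
// below tracks verifying whether that always holds.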
bool TfThreadPoolWorkQueue::IsInWorkerThread() const {
  // TODO(b/192247530): Check if we have cases it is not true.
  return true;
}

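// Builds a TfThreadPoolWorkQueue that owns its own inter-op and intra-op
// thread pools, sized by the given thread counts. A private Wrapper subclass
// keeps the pools alive for the lifetime of the returned work queue.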
std::unique_ptr<TfThreadPoolWorkQueue> CreateDefaultTfThreadPoolWorkQueue(
    int num_inter_op_threads, int num_intra_op_threads) {
  struct ThreadPools {
    TfThreadPool inter_op_threadpool;
    TfThreadPool intra_op_threadpool;

    ThreadPools(int num_inter_op_threads, int num_intra_op_threads)
        : inter_op_threadpool("default_work_queue_inter", num_inter_op_threads),
          intra_op_threadpool("default_work_queue_intra",
                              num_intra_op_threads) {}
  };

  class Wrapper : public TfThreadPoolWorkQueue {
   public:
    explicit Wrapper(std::unique_ptr<ThreadPools> thread_pools)
        : TfThreadPoolWorkQueue(&thread_pools->inter_op_threadpool,
                                &thread_pools->intra_op_threadpool),
          thread_pools_(std::move(thread_pools)) {}

    ~Wrapper() override = default;

   private:
    std::unique_ptr<ThreadPools> thread_pools_;
  };

  return std::make_unique<Wrapper>(std::make_unique<ThreadPools>(
      num_inter_op_threads, num_intra_op_threads));
}

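// Usage sketch (illustrative only, not part of this file): a caller would
// typically construct the default queue and submit inter-op work through it,
// e.g.
//
//   auto work_queue = CreateDefaultTfThreadPoolWorkQueue(
//       /*num_inter_op_threads=*/4, /*num_intra_op_threads=*/4);
//   work_queue->AddTask(tfrt::TaskFunction([] { /* some inter-op work */ }));
//
// The thread counts above are placeholder values chosen for the example.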
}  // namespace tfrt_stub
}  // namespace tensorflow