xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/runtime/async_runtime.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef XLA_RUNTIME_ASYNC_RUNTIME_H_
17 #define XLA_RUNTIME_ASYNC_RUNTIME_H_
18 
19 #define EIGEN_USE_THREADS
20 
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool"
#include "llvm/ADT/STLExtras.h"
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
29 
namespace mlir {
namespace runtime {

// The XLA async runtime implements the MLIR async runtime API that supports
// lowering of the `async` dialect to LLVM and LLVM coroutines.
//
// The types below are only forward declared here; their definitions are not
// part of this header.
struct AsyncToken;
struct AsyncValue;
struct AsyncGroup;

}  // namespace runtime
}  // namespace mlir
41 
42 namespace xla {
43 namespace runtime {
44 
// Forward declaration of the base class for async runtime objects, so that it
// can be named in pointer-typed signatures before it is defined.
class AsyncRuntimeObject;
47 
// Async task runner abstracts over the underlying thread pool (or concurrent
// work queue) implementation.
//
// Implementations override `Schedule` to decide where and when a task runs.
class AsyncTaskRunner {
 public:
  // A unit of work: a callable taking no arguments and returning nothing.
  using Task = std::function<void()>;

  // Virtual destructor: runners are used through base class pointers.
  virtual ~AsyncTaskRunner() = default;

  // Hands `task` over to the underlying executor.
  virtual void Schedule(Task task) = 0;
};
56 
57 class AsyncRuntime {
58  public:
59   using Token = ::mlir::runtime::AsyncToken;
60   using Value = ::mlir::runtime::AsyncValue;
61   using Group = ::mlir::runtime::AsyncGroup;
62 
AsyncRuntime(AsyncTaskRunner * runner)63   explicit AsyncRuntime(AsyncTaskRunner* runner) : runner_(runner) {
64     assert(runner != nullptr && "async task runner must be not null");
65   }
66 
67   // We need a default constructor to define a thread local variable for async
68   // runtime passing between tasks (see implementation in async_runtime_api.cc).
AsyncRuntime()69   AsyncRuntime() : runner_(nullptr) {}
70 
71   // ------------------------------------------------------------------------ //
72   // Implicit AsyncRuntime propagation.
73   // ------------------------------------------------------------------------ //
74 
75   // Set the AsyncRuntime that will be implicitly propagated to all async tasks.
76   //
77   // On every launch of an async task (see `async_runtime_api.h`), current async
78   // runtime will be captured, and restored when the task will start its
79   // execution on a different thread.
80   static void Set(AsyncRuntime runtime);
81 
82   // Returns the current async runtime.
83   static AsyncRuntime& GetCurrentRuntime();
84 
85   // ------------------------------------------------------------------------ //
86   // Async Token API.
87   // ------------------------------------------------------------------------ //
88 
89   // Creates a new token in not-ready state.
90   Token* CreateToken();
91 
92   // Switches the token to the available state and runs all the awaiters.
93   void SetAvailable(Token* token);
94 
95   // Switches the token to the error state and runs all the awaiters.
96   void SetError(Token* token);
97 
98   // Returns `true` if the token is in the error state.
99   bool IsError(Token* token);
100 
101   // Blocks the caller thread until the token becomes ready.
102   void AwaitToken(Token* token);
103 
104   // ------------------------------------------------------------------------ //
105   // Async Value API.
106   // ------------------------------------------------------------------------ //
107 
108   // Creates a new value in not-ready state with a storage of the given size.
109   Value* CreateValue(size_t size, size_t alignment);
110 
111   // Switches the value to the available state and runs all the awaiters.
112   void SetAvailable(Value* value);
113 
114   // Switches the value to the error state and runs all the awaiters.
115   void SetError(Value* value);
116 
117   // Returns `true` if the value is in the error state.
118   bool IsError(Value* value);
119 
120   // Blocks the caller thread until the value becomes ready.
121   void AwaitValue(Value* value);
122 
123   // ------------------------------------------------------------------------ //
124   // Async Group API.
125   // ------------------------------------------------------------------------ //
126 
127   // Creates a new empty group.
128   Group* CreateGroup(int64_t size);
129 
130   // Adds `token` to the `group`.
131   size_t AddTokenToGroup(Group* group, Token* token);
132 
133   // Returns `true` if the group is in the error state (any of the tokens or
134   // values added to the group is in the error state).
135   bool IsError(Group* group);
136 
137   // Blocks the caller thread until the group becomes ready (all tokens that
138   // were added to the group are emplaced).
139   void AwaitGroup(Group* group);
140 
141   // ------------------------------------------------------------------------ //
142   // Execution and continuation based resumption API.
143   // ------------------------------------------------------------------------ //
144 
145   // Execute the callable `f` on a thread managed by the runtime.
146   template <typename F>
147   void Execute(F&& f);
148 
149   // Await operation that do not block the caller thread, but instead execute
150   // the callable `F` when the token/group become ready.
151   template <typename F>
152   void AwaitToken(Token* token, F&& f);
153   template <typename F>
154   void AwaitValue(Value* value, F&& f);
155   template <typename F>
156   void AwaitGroup(Group* group, F&& f);
157 
158   // ------------------------------------------------------------------------ //
159 
160   // Returns a pointer to the async value storage.
161   static void* GetStorage(Value* value);
162 
163   // Extracts async value that holds a chain owned by the value.
164   static tfrt::AsyncValue* GetAsyncValue(Value* value);
165 
166   // Extracts async value that is owned by the token.
167   static tfrt::AsyncValue* GetAsyncValue(Token* token);
168 
169   // Extracts async value that signals group completion.
170   static tfrt::AsyncValue* GetAsyncValue(Group* group);
171 
172   // Reference counting operations for the runtime objects.
173   static void AddRef(AsyncRuntimeObject* obj, unsigned count = 1);
174   static void DropRef(AsyncRuntimeObject* obj, unsigned count = 1);
175 
176   // Convert Token/Value/Group to AsyncRuntimeObject*;
177   static AsyncRuntimeObject* ToAsyncRuntimeObject(Token* token);
178   static AsyncRuntimeObject* ToAsyncRuntimeObject(Value* value);
179   static AsyncRuntimeObject* ToAsyncRuntimeObject(Group* group);
180 
runner()181   AsyncTaskRunner* runner() const { return runner_; }
182 
183  private:
184   // Blocks the caller thread until awaitable async value becomes available.
185   void Await(tfrt::AsyncValue* awaitable);
186 
187   AsyncTaskRunner* runner_;  // must outlive *this
188 };
189 
190 // A base class for all Async dialect types reference counted at runtime.
191 class AsyncRuntimeObject : public tfrt::ReferenceCounted<AsyncRuntimeObject> {
192  public:
193   using ReferenceCounted::ReferenceCounted;  // inherit constructors
194   virtual ~AsyncRuntimeObject() = default;
195 };
196 
197 template <typename F>
Execute(F && f)198 void AsyncRuntime::Execute(F&& f) {
199   runner_->Schedule(std::forward<F>(f));
200 }
201 
202 template <typename F>
AwaitToken(Token * token,F && f)203 void AsyncRuntime::AwaitToken(Token* token, F&& f) {
204   AsyncRuntime::GetAsyncValue(token)->AndThen(std::forward<F>(f));
205 }
206 
207 template <typename F>
AwaitValue(Value * value,F && f)208 void AsyncRuntime::AwaitValue(Value* value, F&& f) {
209   AsyncRuntime::GetAsyncValue(value)->AndThen(std::forward<F>(f));
210 }
211 
212 template <typename F>
AwaitGroup(Group * group,F && f)213 void AsyncRuntime::AwaitGroup(Group* group, F&& f) {
214   AsyncRuntime::GetAsyncValue(group)->AndThen(std::forward<F>(f));
215 }
216 
217 // Runs async tasks by enqueing them into the host context work queue.
218 class HostContextAsyncTaskRunner : public AsyncTaskRunner {
219  public:
HostContextAsyncTaskRunner(tfrt::HostContext * host)220   explicit HostContextAsyncTaskRunner(tfrt::HostContext* host) : host_(host) {}
Schedule(Task task)221   void Schedule(Task task) override { EnqueueWork(host_, std::move(task)); }
222 
223  private:
224   tfrt::HostContext* host_;
225 };
226 
227 // Runs async tasks by scheduling them into the Eigen thread pool.
228 class EigenThreadPoolAsyncTaskRunner : public AsyncTaskRunner {
229  public:
EigenThreadPoolAsyncTaskRunner(Eigen::ThreadPoolInterface * thread_pool)230   explicit EigenThreadPoolAsyncTaskRunner(
231       Eigen::ThreadPoolInterface* thread_pool)
232       : thread_pool_(thread_pool) {}
Schedule(Task task)233   void Schedule(Task task) override { thread_pool_->Schedule(std::move(task)); }
234 
235  private:
236   Eigen::ThreadPoolInterface* thread_pool_;
237 };
238 
239 }  // namespace runtime
240 }  // namespace xla
241 
242 #endif  // XLA_RUNTIME_ASYNC_RUNTIME_H_
243