1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef XLA_RUNTIME_ASYNC_RUNTIME_H_
17 #define XLA_RUNTIME_ASYNC_RUNTIME_H_
18
19 #define EIGEN_USE_THREADS
20
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <utility>

#include "third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool"
#include "llvm/ADT/STLExtras.h"
#include "tfrt/host_context/async_dispatch.h"  // from @tf_runtime
#include "tfrt/support/forward_decls.h"  // from @tf_runtime
29
namespace mlir {
namespace runtime {

// The XLA async runtime implements the MLIR async runtime API, which supports
// lowering the `async` dialect to LLVM and LLVM coroutines. The types below
// are only forward declared in this header.
struct AsyncGroup;
struct AsyncToken;
struct AsyncValue;

}  // namespace runtime
}  // namespace mlir
41
42 namespace xla {
43 namespace runtime {
44
45 // Forward declare a base class for async runtime objects.
46 class AsyncRuntimeObject;
47
// Interface that hides the concrete thread pool (or concurrent work queue)
// used to run async tasks. Implementations decide how a scheduled task is
// executed.
class AsyncTaskRunner {
 public:
  // A unit of work: a callable that takes no arguments and returns nothing.
  using Task = std::function<void()>;

  // Hands `task` over to the underlying executor.
  virtual void Schedule(Task task) = 0;

  // Virtual destructor: runners may be destroyed through a base pointer.
  virtual ~AsyncTaskRunner() = default;
};
56
57 class AsyncRuntime {
58 public:
59 using Token = ::mlir::runtime::AsyncToken;
60 using Value = ::mlir::runtime::AsyncValue;
61 using Group = ::mlir::runtime::AsyncGroup;
62
AsyncRuntime(AsyncTaskRunner * runner)63 explicit AsyncRuntime(AsyncTaskRunner* runner) : runner_(runner) {
64 assert(runner != nullptr && "async task runner must be not null");
65 }
66
67 // We need a default constructor to define a thread local variable for async
68 // runtime passing between tasks (see implementation in async_runtime_api.cc).
AsyncRuntime()69 AsyncRuntime() : runner_(nullptr) {}
70
71 // ------------------------------------------------------------------------ //
72 // Implicit AsyncRuntime propagation.
73 // ------------------------------------------------------------------------ //
74
75 // Set the AsyncRuntime that will be implicitly propagated to all async tasks.
76 //
77 // On every launch of an async task (see `async_runtime_api.h`), current async
78 // runtime will be captured, and restored when the task will start its
79 // execution on a different thread.
80 static void Set(AsyncRuntime runtime);
81
82 // Returns the current async runtime.
83 static AsyncRuntime& GetCurrentRuntime();
84
85 // ------------------------------------------------------------------------ //
86 // Async Token API.
87 // ------------------------------------------------------------------------ //
88
89 // Creates a new token in not-ready state.
90 Token* CreateToken();
91
92 // Switches the token to the available state and runs all the awaiters.
93 void SetAvailable(Token* token);
94
95 // Switches the token to the error state and runs all the awaiters.
96 void SetError(Token* token);
97
98 // Returns `true` if the token is in the error state.
99 bool IsError(Token* token);
100
101 // Blocks the caller thread until the token becomes ready.
102 void AwaitToken(Token* token);
103
104 // ------------------------------------------------------------------------ //
105 // Async Value API.
106 // ------------------------------------------------------------------------ //
107
108 // Creates a new value in not-ready state with a storage of the given size.
109 Value* CreateValue(size_t size, size_t alignment);
110
111 // Switches the value to the available state and runs all the awaiters.
112 void SetAvailable(Value* value);
113
114 // Switches the value to the error state and runs all the awaiters.
115 void SetError(Value* value);
116
117 // Returns `true` if the value is in the error state.
118 bool IsError(Value* value);
119
120 // Blocks the caller thread until the value becomes ready.
121 void AwaitValue(Value* value);
122
123 // ------------------------------------------------------------------------ //
124 // Async Group API.
125 // ------------------------------------------------------------------------ //
126
127 // Creates a new empty group.
128 Group* CreateGroup(int64_t size);
129
130 // Adds `token` to the `group`.
131 size_t AddTokenToGroup(Group* group, Token* token);
132
133 // Returns `true` if the group is in the error state (any of the tokens or
134 // values added to the group is in the error state).
135 bool IsError(Group* group);
136
137 // Blocks the caller thread until the group becomes ready (all tokens that
138 // were added to the group are emplaced).
139 void AwaitGroup(Group* group);
140
141 // ------------------------------------------------------------------------ //
142 // Execution and continuation based resumption API.
143 // ------------------------------------------------------------------------ //
144
145 // Execute the callable `f` on a thread managed by the runtime.
146 template <typename F>
147 void Execute(F&& f);
148
149 // Await operation that do not block the caller thread, but instead execute
150 // the callable `F` when the token/group become ready.
151 template <typename F>
152 void AwaitToken(Token* token, F&& f);
153 template <typename F>
154 void AwaitValue(Value* value, F&& f);
155 template <typename F>
156 void AwaitGroup(Group* group, F&& f);
157
158 // ------------------------------------------------------------------------ //
159
160 // Returns a pointer to the async value storage.
161 static void* GetStorage(Value* value);
162
163 // Extracts async value that holds a chain owned by the value.
164 static tfrt::AsyncValue* GetAsyncValue(Value* value);
165
166 // Extracts async value that is owned by the token.
167 static tfrt::AsyncValue* GetAsyncValue(Token* token);
168
169 // Extracts async value that signals group completion.
170 static tfrt::AsyncValue* GetAsyncValue(Group* group);
171
172 // Reference counting operations for the runtime objects.
173 static void AddRef(AsyncRuntimeObject* obj, unsigned count = 1);
174 static void DropRef(AsyncRuntimeObject* obj, unsigned count = 1);
175
176 // Convert Token/Value/Group to AsyncRuntimeObject*;
177 static AsyncRuntimeObject* ToAsyncRuntimeObject(Token* token);
178 static AsyncRuntimeObject* ToAsyncRuntimeObject(Value* value);
179 static AsyncRuntimeObject* ToAsyncRuntimeObject(Group* group);
180
runner()181 AsyncTaskRunner* runner() const { return runner_; }
182
183 private:
184 // Blocks the caller thread until awaitable async value becomes available.
185 void Await(tfrt::AsyncValue* awaitable);
186
187 AsyncTaskRunner* runner_; // must outlive *this
188 };
189
190 // A base class for all Async dialect types reference counted at runtime.
191 class AsyncRuntimeObject : public tfrt::ReferenceCounted<AsyncRuntimeObject> {
192 public:
193 using ReferenceCounted::ReferenceCounted; // inherit constructors
194 virtual ~AsyncRuntimeObject() = default;
195 };
196
197 template <typename F>
Execute(F && f)198 void AsyncRuntime::Execute(F&& f) {
199 runner_->Schedule(std::forward<F>(f));
200 }
201
202 template <typename F>
AwaitToken(Token * token,F && f)203 void AsyncRuntime::AwaitToken(Token* token, F&& f) {
204 AsyncRuntime::GetAsyncValue(token)->AndThen(std::forward<F>(f));
205 }
206
207 template <typename F>
AwaitValue(Value * value,F && f)208 void AsyncRuntime::AwaitValue(Value* value, F&& f) {
209 AsyncRuntime::GetAsyncValue(value)->AndThen(std::forward<F>(f));
210 }
211
212 template <typename F>
AwaitGroup(Group * group,F && f)213 void AsyncRuntime::AwaitGroup(Group* group, F&& f) {
214 AsyncRuntime::GetAsyncValue(group)->AndThen(std::forward<F>(f));
215 }
216
217 // Runs async tasks by enqueing them into the host context work queue.
218 class HostContextAsyncTaskRunner : public AsyncTaskRunner {
219 public:
HostContextAsyncTaskRunner(tfrt::HostContext * host)220 explicit HostContextAsyncTaskRunner(tfrt::HostContext* host) : host_(host) {}
Schedule(Task task)221 void Schedule(Task task) override { EnqueueWork(host_, std::move(task)); }
222
223 private:
224 tfrt::HostContext* host_;
225 };
226
227 // Runs async tasks by scheduling them into the Eigen thread pool.
228 class EigenThreadPoolAsyncTaskRunner : public AsyncTaskRunner {
229 public:
EigenThreadPoolAsyncTaskRunner(Eigen::ThreadPoolInterface * thread_pool)230 explicit EigenThreadPoolAsyncTaskRunner(
231 Eigen::ThreadPoolInterface* thread_pool)
232 : thread_pool_(thread_pool) {}
Schedule(Task task)233 void Schedule(Task task) override { thread_pool_->Schedule(std::move(task)); }
234
235 private:
236 Eigen::ThreadPoolInterface* thread_pool_;
237 };
238
239 } // namespace runtime
240 } // namespace xla
241
242 #endif // XLA_RUNTIME_ASYNC_RUNTIME_H_
243