/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/runtime/async_runtime.h"

#include <atomic>
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <type_traits>
#include <utility>

#include "absl/base/dynamic_annotations.h"
#include "llvm/Support/MathExtras.h"
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/diagnostic.h"  // from @tf_runtime
#include "tfrt/support/alloc.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime

// -------------------------------------------------------------------------- //
// Define AsyncToken and AsyncGroup in the mlir::runtime namespace to implement
// opaque structs defined in the MLIR Async Runtime API header file.
// -------------------------------------------------------------------------- //

namespace mlir {
namespace runtime {

using tfrt::AlignedAlloc;
using tfrt::AlignedFree;
using tfrt::AsyncValueRef;
using tfrt::Chain;
using tfrt::GetReadyChain;
using tfrt::MakeConstructedAsyncValueRef;

using xla::runtime::AsyncRuntimeObject;

class AsyncToken : public AsyncRuntimeObject {
 public:
  explicit AsyncToken(unsigned ref_count = 1)
      : AsyncRuntimeObject(ref_count),
        chain_(MakeConstructedAsyncValueRef<Chain>()) {}

  tfrt::AsyncValue* GetAsyncValue() const { return chain_.GetAsyncValue(); }

 private:
  AsyncValueRef<Chain> chain_;
};

class AsyncValue : public AsyncRuntimeObject {
 public:
  explicit AsyncValue(size_t size, size_t alignment, unsigned ref_count = 1)
      : AsyncRuntimeObject(ref_count),
        storage_(MakeConstructedAsyncValueRef<Storage>(size, alignment)) {
    // Storage memory will be initialized by the compiled kernel.
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(GetStorage(), size);
  }

  void* GetStorage() const {
    assert(!GetAsyncValue()->IsError() && "unexpected error state");
    if (storage_->is_inline) return &storage_->inline_buffer;
    return storage_->allocated_buffer;
  }

  tfrt::AsyncValue* GetAsyncValue() const { return storage_.GetAsyncValue(); }

 private:
  // If the requested async value storage is small, use the inlined storage.
  // Fall back on dynamic allocation if the requested storage size is large.
  struct Storage {
    static const int kSize = 128;  // enough to fit memref descriptor of rank 5
    static const int kAlign = alignof(std::max_align_t);
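    // Note (assuming the standard MLIR strided memref descriptor layout on a
    // 64-bit target): a rank-5 descriptor holds two pointers, an offset, five
    // sizes and five strides, i.e. 13 * 8 = 104 bytes, which fits into the
    // 128-byte inline buffer above.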

    Storage(size_t size, size_t alignment)
        : is_inline(CanStoreInline(size, alignment)) {
      if (!is_inline) allocated_buffer = AlignedAlloc(alignment, size);
    }

    ~Storage() {
      if (!is_inline) AlignedFree(allocated_buffer);
    }

    static bool CanStoreInline(size_t size, size_t alignment) {
      assert(llvm::isPowerOf2_32(alignment));
      return size <= kSize && alignment <= kAlign;
    }

    bool is_inline;
    union {
      std::aligned_storage<kSize, kAlign>::type inline_buffer;
      void* allocated_buffer;
    };
  };

  AsyncValueRef<Storage> storage_;
};

class AsyncGroup : public AsyncRuntimeObject {
 public:
  explicit AsyncGroup(int64_t size, unsigned ref_count = 1)
      : AsyncRuntimeObject(ref_count),
        size_(size),
        rank_(0),
        pending_tokens_(size),
        num_errors_(0),
        completed_(size == 0 ? GetReadyChain()
                             : MakeConstructedAsyncValueRef<Chain>()) {
    assert(size_ >= 0 && "size can't be negative");
  }

  size_t AddToken(AsyncToken* token) {
    size_t rank = rank_.fetch_add(1, std::memory_order_relaxed);
    assert(rank < size_ && "can't add more tokens than the group size");

    // When the token becomes available, decrement the number of pending tokens
    // and maybe make the group completion async value available.
    token->GetAsyncValue()->AndThen([group = this, token]() {
      // Increment the number of errors in the group.
      if (token->GetAsyncValue()->IsError()) group->num_errors_.fetch_add(1);

      // Pending tokens can't drop below zero.
      assert(group->pending_tokens_ > 0 && "wrong group size");

      // We track the group error state with the number of errors, and never
      // set the completion async value state to error.
      if (group->pending_tokens_.fetch_sub(1) == 1)
        group->completed_.SetStateConcrete();
    });

    return rank;
  }

  tfrt::AsyncValue* GetCompletionAsyncValue() const {
    return completed_.GetAsyncValue();
  }

  bool IsError() const { return num_errors_.load() != 0; }

 private:
  int64_t size_;
  std::atomic<int64_t> rank_;
  std::atomic<int64_t> pending_tokens_;
  std::atomic<int64_t> num_errors_;

  // Async value that tracks the group completion; it becomes available when
  // the number of pending tokens drops to zero.
  AsyncValueRef<Chain> completed_;
};

}  // namespace runtime
}  // namespace mlir

// -------------------------------------------------------------------------- //

namespace xla {
namespace runtime {

using tfrt::AsyncValue;
using tfrt::DecodedDiagnostic;

namespace {
// Always keep the currently active async runtime in a thread local variable.
static thread_local AsyncRuntime async_runtime;

static_assert(std::is_trivially_destructible<AsyncRuntime>::value,
              "AsyncRuntime must be trivially destructible");

static_assert(std::is_trivially_copy_assignable<AsyncRuntime>::value,
              "AsyncRuntime must be trivially copy assignable");

static_assert(std::is_trivially_copy_constructible<AsyncRuntime>::value,
              "AsyncRuntime must be trivially copy constructible");

// This is an arbitrary limitation to make sure that AsyncRuntime does not
// become expensive to copy unnoticed.
static_assert(sizeof(AsyncRuntime) == 1 * sizeof(void*),
              "AsyncRuntime must only hold one pointer");

}  // namespace

/*static*/ void AsyncRuntime::Set(AsyncRuntime runtime) {
  assert(runtime.runner() != nullptr);
  async_runtime = runtime;
}

/*static*/ AsyncRuntime& AsyncRuntime::GetCurrentRuntime() {
  assert(async_runtime.runner() != nullptr);
  return async_runtime;
}
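
// The thread-local runtime above is expected to be installed via
// AsyncRuntime::Set by the code that invokes the compiled executable on the
// current thread (hence the non-null runner assertions), and is read back by
// the compiled code through AsyncRuntime::GetCurrentRuntime.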

/*static*/ void* AsyncRuntime::GetStorage(Value* value) {
  return value->GetStorage();
}

/*static*/ AsyncValue* AsyncRuntime::GetAsyncValue(AsyncRuntime::Value* value) {
  return value->GetAsyncValue();
}

/*static*/ AsyncValue* AsyncRuntime::GetAsyncValue(AsyncRuntime::Token* token) {
  return token->GetAsyncValue();
}

/*static*/ AsyncValue* AsyncRuntime::GetAsyncValue(AsyncRuntime::Group* group) {
  return group->GetCompletionAsyncValue();
}

void AsyncRuntime::Await(AsyncValue* awaitable) {
  // Short circuit the trivial case.
  if (awaitable->IsAvailable()) return;
  tfrt::Await({awaitable});
}

/*static*/ void AsyncRuntime::AddRef(AsyncRuntimeObject* obj, unsigned count) {
  assert(count == 1 && "AsyncRuntimeObject can add just one ref");
  obj->AddRef();
}

/*static*/ void AsyncRuntime::DropRef(AsyncRuntimeObject* obj, unsigned count) {
  assert(count == 1 && "AsyncRuntimeObject can drop just one ref");
  obj->DropRef();
}

/*static*/ AsyncRuntimeObject* AsyncRuntime::ToAsyncRuntimeObject(
    AsyncRuntime::Token* token) {
  return static_cast<AsyncRuntimeObject*>(token);
}

/*static*/ AsyncRuntimeObject* AsyncRuntime::ToAsyncRuntimeObject(
    AsyncRuntime::Value* value) {
  return static_cast<AsyncRuntimeObject*>(value);
}

/*static*/ AsyncRuntimeObject* AsyncRuntime::ToAsyncRuntimeObject(
    AsyncRuntime::Group* group) {
  return static_cast<AsyncRuntimeObject*>(group);
}

AsyncRuntime::Token* AsyncRuntime::CreateToken() {
  // AsyncRuntime::Token is created with a reference count of 2 because it will
  // be returned to the `async.execute` caller and will also be emplaced later
  // by the asynchronously executed task. If the caller immediately drops its
  // reference, we must ensure that the token stays alive until the
  // asynchronous operation completes.
  return new AsyncRuntime::Token(/*ref_count=*/2);
}
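
// Illustrative lifetime of a token (a sketch, not code emitted by this file;
// the actual call sequence is produced by the MLIR async lowering):
//
//   AsyncRuntime& rt = AsyncRuntime::GetCurrentRuntime();
//   AsyncRuntime::Token* token = rt.CreateToken();  // ref_count == 2
//   // ... in the asynchronously executed task:
//   rt.SetAvailable(token);                         // drops the task's ref
//   // ... in the caller:
//   rt.AwaitToken(token);
//   AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(token), 1);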

void AsyncRuntime::SetAvailable(AsyncRuntime::Token* token) {
  token->GetAsyncValue()->SetStateConcrete();
  // Async tokens are created with a ref count of 2 to keep the token alive
  // until the async task completes. Drop the extra reference explicitly when
  // the token is emplaced.
  DropRef(token);
}

void AsyncRuntime::SetError(AsyncRuntime::Token* token) {
  // TODO(ezhulenev): Construct a better diagnostic once the async runtime API
  // supports passing custom error messages.
  token->GetAsyncValue()->SetError(DecodedDiagnostic("<async runtime error>"));
  // Async tokens are created with a ref count of 2 to keep the token alive
  // until the async task completes. Drop the extra reference explicitly when
  // the token is emplaced.
  DropRef(token);
}

bool AsyncRuntime::IsError(AsyncRuntime::Token* token) {
  return token->GetAsyncValue()->IsError();
}

void AsyncRuntime::AwaitToken(AsyncRuntime::Token* token) {
  Await(token->GetAsyncValue());
}

AsyncRuntime::Value* AsyncRuntime::CreateValue(size_t size, size_t alignment) {
  // AsyncRuntime::Value is created with a reference count of 2 because it will
  // be returned to the `async.execute` caller and will also be emplaced later
  // by the asynchronously executed task. If the caller immediately drops its
  // reference, we must ensure that the value stays alive until the
  // asynchronous operation completes.
  return new AsyncRuntime::Value(size, alignment, /*ref_count=*/2);
}
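
// Illustrative use of a value by the asynchronously executed task (a sketch;
// the real code is generated by the compiler, and the `float` payload below is
// just an assumed example):
//
//   AsyncRuntime& rt = AsyncRuntime::GetCurrentRuntime();
//   AsyncRuntime::Value* value = rt.CreateValue(sizeof(float), alignof(float));
//   *reinterpret_cast<float*>(AsyncRuntime::GetStorage(value)) = 42.0f;
//   rt.SetAvailable(value);  // drops the task's reference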

void AsyncRuntime::SetAvailable(AsyncRuntime::Value* value) {
  value->GetAsyncValue()->SetStateConcrete();
  // Async values are created with a ref count of 2 to keep the value alive
  // until the async task completes. Drop the extra reference explicitly when
  // the value is emplaced.
  DropRef(value);
}

void AsyncRuntime::SetError(AsyncRuntime::Value* value) {
  // TODO(ezhulenev): Construct a better diagnostic once the async runtime API
  // supports passing custom error messages.
  value->GetAsyncValue()->SetError(DecodedDiagnostic("<async runtime error>"));
  // Async values are created with a ref count of 2 to keep the value alive
  // until the async task completes. Drop the extra reference explicitly when
  // the value is emplaced.
  DropRef(value);
}

bool AsyncRuntime::IsError(AsyncRuntime::Value* value) {
  return value->GetAsyncValue()->IsError();
}

void AsyncRuntime::AwaitValue(AsyncRuntime::Value* value) {
  Await(value->GetAsyncValue());
}

AsyncRuntime::Group* AsyncRuntime::CreateGroup(int64_t size) {
  return new AsyncRuntime::Group(size);
}
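
// Illustrative use of a group (a sketch; the real call sequence is generated
// by the compiler, `rt` is the current AsyncRuntime, and token0/token1 stand
// for previously created tokens):
//
//   AsyncRuntime::Group* group = rt.CreateGroup(/*size=*/2);
//   rt.AddTokenToGroup(group, token0);
//   rt.AddTokenToGroup(group, token1);
//   rt.AwaitGroup(group);  // returns once both tokens are emplaced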

size_t AsyncRuntime::AddTokenToGroup(AsyncRuntime::Group* group,
                                     AsyncRuntime::Token* token) {
  return group->AddToken(token);
}

bool AsyncRuntime::IsError(AsyncRuntime::Group* group) {
  return group->IsError();
}

void AsyncRuntime::AwaitGroup(AsyncRuntime::Group* group) {
  Await(group->GetCompletionAsyncValue());
}

}  // namespace runtime
}  // namespace xla