/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/runtime/async_runtime.h"

#include <atomic>
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <type_traits>
#include <utility>

#include "absl/base/dynamic_annotations.h"
#include "llvm/Support/MathExtras.h"
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
#include "tfrt/host_context/chain.h"  // from @tf_runtime
#include "tfrt/host_context/diagnostic.h"  // from @tf_runtime
#include "tfrt/support/alloc.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime

// -------------------------------------------------------------------------- //
// Define AsyncToken and AsyncGroup in the mlir::runtime namespace to implement
// opaque structs defined in the MLIR Async Runtime API header file.
// -------------------------------------------------------------------------- //

namespace mlir {
namespace runtime {

using tfrt::AlignedAlloc;
using tfrt::AlignedFree;
using tfrt::AsyncValueRef;
using tfrt::Chain;
using tfrt::GetReadyChain;
using tfrt::MakeConstructedAsyncValueRef;

using xla::runtime::AsyncRuntimeObject;

class AsyncToken : public AsyncRuntimeObject {
 public:
  explicit AsyncToken(unsigned ref_count = 1)
      : AsyncRuntimeObject(ref_count),
        chain_(MakeConstructedAsyncValueRef<Chain>()) {}

  tfrt::AsyncValue* GetAsyncValue() const { return chain_.GetAsyncValue(); }

 private:
  AsyncValueRef<Chain> chain_;
};

class AsyncValue : public AsyncRuntimeObject {
 public:
  explicit AsyncValue(size_t size, size_t alignment, unsigned ref_count = 1)
      : AsyncRuntimeObject(ref_count),
        storage_(MakeConstructedAsyncValueRef<Storage>(size, alignment)) {
    // Storage memory will be initialized by the compiled kernel.
    ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(GetStorage(), size);
  }

  void* GetStorage() const {
    assert(!GetAsyncValue()->IsError() && "unexpected error state");
    if (storage_->is_inline) return &storage_->inline_buffer;
    return storage_->allocated_buffer;
  }

  tfrt::AsyncValue* GetAsyncValue() const { return storage_.GetAsyncValue(); }

 private:
  // If the requested async value storage is small, use the inlined storage.
  // Fall back on dynamic allocation if the requested storage size is large.
  struct Storage {
    static const int kSize = 128;  // enough to fit memref descriptor of rank 5
    static const int kAlign = alignof(std::max_align_t);

    Storage(size_t size, size_t alignment)
        : is_inline(CanStoreInline(size, alignment)) {
      if (!is_inline) allocated_buffer = AlignedAlloc(alignment, size);
    }

    ~Storage() {
      if (!is_inline) AlignedFree(allocated_buffer);
    }

    static bool CanStoreInline(size_t size, size_t alignment) {
      assert(llvm::isPowerOf2_32(alignment));
      return size <= kSize && alignment <= kAlign;
    }

    bool is_inline;
    union {
      std::aligned_storage<kSize, kAlign>::type inline_buffer;
      void* allocated_buffer;
    };
  };

  AsyncValueRef<Storage> storage_;
};

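// A minimal usage sketch of AsyncValue's storage policy (hypothetical sizes,
// not code that exists elsewhere in the runtime): payloads that fit in the
// 128-byte inline buffer and require at most max_align_t alignment stay
// inline, while anything larger or more strictly aligned goes through
// AlignedAlloc and is released in ~Storage.
//
//   auto* small = new AsyncValue(/*size=*/64, /*alignment=*/8);    // inline
//   auto* large = new AsyncValue(/*size=*/512, /*alignment=*/64);  // heap
//   void* ptr = large->GetStorage();  // points at the AlignedAlloc'd buffer
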
class AsyncGroup : public AsyncRuntimeObject {
 public:
  explicit AsyncGroup(int64_t size, unsigned ref_count = 1)
      : AsyncRuntimeObject(ref_count),
        size_(size),
        rank_(0),
        pending_tokens_(size),
        num_errors_(0),
        completed_(size == 0 ? GetReadyChain()
                             : MakeConstructedAsyncValueRef<Chain>()) {
    assert(size_ >= 0 && "size can't be negative");
  }

  size_t AddToken(AsyncToken* token) {
    size_t rank = rank_.fetch_add(1, std::memory_order_relaxed);
    assert(rank < size_ && "can't add more tokens than the group size");

    // When the token becomes available, decrement the number of pending tokens
    // and maybe make the group completion async value available.
    token->GetAsyncValue()->AndThen([group = this, token]() {
      // Increment the number of errors in the group.
      if (token->GetAsyncValue()->IsError()) group->num_errors_.fetch_add(1);

      // Pending tokens can't drop below zero.
      assert(group->pending_tokens_ > 0 && "wrong group size");

      // We track the group error state with the number of errors and never set
      // the completion async value state to error.
      if (group->pending_tokens_.fetch_sub(1) == 1)
        group->completed_.SetStateConcrete();
    });

    return rank;
  }

  tfrt::AsyncValue* GetCompletionAsyncValue() const {
    return completed_.GetAsyncValue();
  }

  bool IsError() const { return num_errors_.load() != 0; }

 private:
  int64_t size_;
  std::atomic<int64_t> rank_;
  std::atomic<int64_t> pending_tokens_;
  std::atomic<int64_t> num_errors_;

  // Async value that keeps track of the group completion; it becomes available
  // when the number of pending tokens drops to zero.
  AsyncValueRef<Chain> completed_;
};

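// A minimal usage sketch of AsyncGroup (hypothetical call site, assuming the
// tokens are created elsewhere): the completion chain becomes available only
// after every added token reaches a concrete or error state, and errors are
// surfaced via IsError() rather than through the completion async value.
//
//   auto* group = new AsyncGroup(/*size=*/2);
//   group->AddToken(token0);  // returns rank 0
//   group->AddToken(token1);  // returns rank 1
//   // Once both tokens are emplaced, GetCompletionAsyncValue() is available;
//   // if either token was set to an error, group->IsError() returns true.
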
}  // namespace runtime
}  // namespace mlir

// -------------------------------------------------------------------------- //

namespace xla {
namespace runtime {

using tfrt::AsyncValue;
using tfrt::DecodedDiagnostic;

namespace {
// Always keep the current active async runtime in a thread local variable.
static thread_local AsyncRuntime async_runtime;

static_assert(std::is_trivially_destructible<AsyncRuntime>::value,
              "AsyncRuntime must be trivially destructible");

static_assert(std::is_trivially_copy_assignable<AsyncRuntime>::value,
              "AsyncRuntime must be trivially copy assignable");

static_assert(std::is_trivially_copy_constructible<AsyncRuntime>::value,
              "AsyncRuntime must be trivially copy constructible");

// This is an arbitrary limitation to make sure that AsyncRuntime does not
// become expensive to copy without anyone noticing.
static_assert(sizeof(AsyncRuntime) == 1 * sizeof(void*),
              "AsyncRuntime must only hold one pointer");

}  // namespace

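// A minimal usage sketch (hypothetical call site): a runtime with a non-null
// runner must be published for the current thread via Set() before compiled
// code reads it back through GetCurrentRuntime().
//
//   AsyncRuntime::Set(runtime);  // `runtime.runner()` must be non-null
//   AsyncRuntime& current = AsyncRuntime::GetCurrentRuntime();
//   assert(current.runner() != nullptr);
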
/*static*/ void AsyncRuntime::Set(AsyncRuntime runtime) {
  assert(runtime.runner() != nullptr);
  async_runtime = runtime;
}

/*static*/ AsyncRuntime& AsyncRuntime::GetCurrentRuntime() {
  assert(async_runtime.runner() != nullptr);
  return async_runtime;
}

/*static*/ void* AsyncRuntime::GetStorage(Value* value) {
  return value->GetStorage();
}

/*static*/ AsyncValue* AsyncRuntime::GetAsyncValue(AsyncRuntime::Value* value) {
  return value->GetAsyncValue();
}

/*static*/ AsyncValue* AsyncRuntime::GetAsyncValue(AsyncRuntime::Token* token) {
  return token->GetAsyncValue();
}

/*static*/ AsyncValue* AsyncRuntime::GetAsyncValue(AsyncRuntime::Group* group) {
  return group->GetCompletionAsyncValue();
}

void AsyncRuntime::Await(AsyncValue* awaitable) {
  // Short circuit the trivial case.
  if (awaitable->IsAvailable()) return;
  tfrt::Await({awaitable});
}

/*static*/ void AsyncRuntime::AddRef(AsyncRuntimeObject* obj, unsigned count) {
  assert(count == 1 && "AsyncRuntimeObject can add just one ref");
  obj->AddRef();
}

/*static*/ void AsyncRuntime::DropRef(AsyncRuntimeObject* obj, unsigned count) {
  assert(count == 1 && "AsyncRuntimeObject can drop just one ref");
  obj->DropRef();
}

/*static*/ AsyncRuntimeObject* AsyncRuntime::ToAsyncRuntimeObject(
    AsyncRuntime::Token* token) {
  return static_cast<AsyncRuntimeObject*>(token);
}

/*static*/ AsyncRuntimeObject* AsyncRuntime::ToAsyncRuntimeObject(
    AsyncRuntime::Value* value) {
  return static_cast<AsyncRuntimeObject*>(value);
}

/*static*/ AsyncRuntimeObject* AsyncRuntime::ToAsyncRuntimeObject(
    AsyncRuntime::Group* group) {
  return static_cast<AsyncRuntimeObject*>(group);
}

AsyncRuntime::Token* AsyncRuntime::CreateToken() {
  // AsyncRuntime::Token is created with a reference count of 2 because it will
  // be returned to the `async.execute` caller and will also later be emplaced
  // by the asynchronously executed task. If the caller immediately drops its
  // reference, we must ensure that the token stays alive until the
  // asynchronous operation is completed.
  return new AsyncRuntime::Token(/*ref_count=*/2);
}

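// A sketch of the two-reference token lifecycle (hypothetical caller code):
// one reference belongs to the `async.execute` caller and one to the producer
// task, so the token survives even if the caller drops its reference first.
//
//   AsyncRuntime::Token* token = AsyncRuntime::CreateToken();  // ref_count=2
//   // Producer side: emplacing the token drops the producer's reference.
//   AsyncRuntime::SetAvailable(token);
//   // Caller side: drop the remaining reference when done with the token.
//   AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(token), 1);
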
void AsyncRuntime::SetAvailable(AsyncRuntime::Token* token) {
  token->GetAsyncValue()->SetStateConcrete();
  // Async tokens are created with a ref count of 2 to keep the token alive
  // until the async task completes. Drop the extra reference explicitly when
  // the token is emplaced.
  DropRef(token);
}

void AsyncRuntime::SetError(AsyncRuntime::Token* token) {
  // TODO(ezhulenev): Construct better diagnostics when the async runtime API
  // supports passing custom error messages.
  token->GetAsyncValue()->SetError(DecodedDiagnostic("<async runtime error>"));
  // Async tokens are created with a ref count of 2 to keep the token alive
  // until the async task completes. Drop the extra reference explicitly when
  // the token is emplaced.
  DropRef(token);
}

bool AsyncRuntime::IsError(AsyncRuntime::Token* token) {
  return token->GetAsyncValue()->IsError();
}

void AsyncRuntime::AwaitToken(AsyncRuntime::Token* token) {
  Await(token->GetAsyncValue());
}

AsyncRuntime::Value* AsyncRuntime::CreateValue(size_t size, size_t alignment) {
  // AsyncRuntime::Value is created with a reference count of 2 because it will
  // be returned to the `async.execute` caller and will also later be emplaced
  // by the asynchronously executed task. If the caller immediately drops its
  // reference, we must ensure that the value stays alive until the
  // asynchronous operation is completed.
  return new AsyncRuntime::Value(size, alignment, /*ref_count=*/2);
}

void AsyncRuntime::SetAvailable(AsyncRuntime::Value* value) {
  value->GetAsyncValue()->SetStateConcrete();
  // Async values are created with a ref count of 2 to keep the value alive
  // until the async task completes. Drop the extra reference explicitly when
  // the value is emplaced.
  DropRef(value);
}

void AsyncRuntime::SetError(AsyncRuntime::Value* value) {
  // TODO(ezhulenev): Construct better diagnostics when the async runtime API
  // supports passing custom error messages.
  value->GetAsyncValue()->SetError(DecodedDiagnostic("<async runtime error>"));
  // Async values are created with a ref count of 2 to keep the value alive
  // until the async task completes. Drop the extra reference explicitly when
  // the value is emplaced.
  DropRef(value);
}

bool AsyncRuntime::IsError(AsyncRuntime::Value* value) {
  return value->GetAsyncValue()->IsError();
}

void AsyncRuntime::AwaitValue(AsyncRuntime::Value* value) {
  Await(value->GetAsyncValue());
}

AsyncRuntime::Group* AsyncRuntime::CreateGroup(int64_t size) {
  return new AsyncRuntime::Group(size);
}

size_t AsyncRuntime::AddTokenToGroup(AsyncRuntime::Group* group,
                                     AsyncRuntime::Token* token) {
  return group->AddToken(token);
}

bool AsyncRuntime::IsError(AsyncRuntime::Group* group) {
  return group->IsError();
}

void AsyncRuntime::AwaitGroup(AsyncRuntime::Group* group) {
  Await(group->GetCompletionAsyncValue());
}

}  // namespace runtime
}  // namespace xla