xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir/utils/runtime/async_runtime_api.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/mlir/utils/runtime/async_runtime_api.h"
17 
18 #include <stdlib.h>
19 
20 #include <cstddef>
21 #include <iostream>
22 #include <ostream>
23 #include <thread>  // NOLINT TODO(ezhulenev): Remove this header.
24 #include <type_traits>
25 
26 #include "mlir/ExecutionEngine/AsyncRuntime.h"  // from @llvm-project
27 #include "tensorflow/compiler/xla/runtime/async_runtime.h"
28 #include "tfrt/host_context/async_value.h"  // from @tf_runtime
29 #include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
30 #include "tfrt/host_context/chain.h"  // from @tf_runtime
31 #include "tfrt/support/alloc.h"  // from @tf_runtime
32 #include "tfrt/support/msan.h"  // from @tf_runtime
33 
34 namespace xla {
35 namespace runtime {
36 
37 using tfrt::AlignedAlloc;
38 using tfrt::AlignedFree;
39 using tfrt::AsyncValue;
40 using tfrt::AsyncValueRef;
41 using tfrt::Chain;
42 
ConvertAsyncTokenToChain(AsyncRuntime::Token * token)43 AsyncValueRef<Chain> ConvertAsyncTokenToChain(AsyncRuntime::Token *token) {
44   auto *async_value = AsyncRuntime::GetAsyncValue(token);
45   auto out_chain = AsyncValueRef<Chain>(FormRef(async_value));
46   AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(token));
47   return out_chain;
48 }
49 
ExtractAsyncValue(AsyncRuntime::Value * value,AsyncValue * dst,llvm::function_ref<void (void * storage,AsyncValue * dst)> emplace_fn)50 void ExtractAsyncValue(
51     AsyncRuntime::Value *value, AsyncValue *dst,
52     llvm::function_ref<void(void *storage, AsyncValue *dst)> emplace_fn) {
53   auto *async_value = AsyncRuntime::GetAsyncValue(value);
54 
55   // Fast path if async value is already available.
56   if (async_value->IsAvailable()) {
57     void *storage = AsyncRuntime::GetStorage(value);
58     emplace_fn(storage, dst);
59     AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(value));
60     return;
61   }
62 
63   // Wait for the async value completion, and emplace the `dst`.
64   async_value->AndThen([value, emplace_fn, dst = FormRef(dst)]() {
65     void *storage = AsyncRuntime::GetStorage(value);
66     emplace_fn(storage, dst.get());
67     AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(value));
68   });
69 }
70 
ExtractAsyncValue(AsyncRuntime::Value * value,AsyncValue * dst,void * context,llvm::function_ref<void (void * storage,AsyncValue * dst,void * context)> emplace_fn)71 void ExtractAsyncValue(
72     AsyncRuntime::Value *value, AsyncValue *dst, void *context,
73     llvm::function_ref<void(void *storage, AsyncValue *dst, void *context)>
74         emplace_fn) {
75   auto *async_value = AsyncRuntime::GetAsyncValue(value);
76 
77   // Fast path if async value is already available.
78   if (async_value->IsAvailable()) {
79     void *storage = AsyncRuntime::GetStorage(value);
80     emplace_fn(storage, dst, context);
81     AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(value));
82     return;
83   }
84 
85   // Wait for the async value completion, and emplace the `dst`.
86   async_value->AndThen([value, emplace_fn, context, dst = FormRef(dst)]() {
87     void *storage = AsyncRuntime::GetStorage(value);
88     emplace_fn(storage, dst.get(), context);
89     AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(value));
90   });
91 }
92 
AsyncRuntimeApiSymbolMap(llvm::orc::MangleAndInterner mangle)93 llvm::orc::SymbolMap AsyncRuntimeApiSymbolMap(
94     llvm::orc::MangleAndInterner mangle) {
95   llvm::orc::SymbolMap symbol_map;
96 
97   auto bind = [&](llvm::StringRef name, auto symbol_ptr) {
98     symbol_map[mangle(name)] = llvm::JITEvaluatedSymbol(
99         llvm::pointerToJITTargetAddress(symbol_ptr), llvm::JITSymbolFlags());
100   };
101 
102   bind("mlirAsyncRuntimeAddRef", &mlir::runtime::mlirAsyncRuntimeAddRef);
103   bind("mlirAsyncRuntimeDropRef", &mlir::runtime::mlirAsyncRuntimeDropRef);
104   bind("mlirAsyncRuntimeExecute", &mlir::runtime::mlirAsyncRuntimeExecute);
105   bind("mlirAsyncRuntimeGetValueStorage",
106        &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
107   bind("mlirAsyncRuntimeCreateToken",
108        &mlir::runtime::mlirAsyncRuntimeCreateToken);
109   bind("mlirAsyncRuntimeCreateValue",
110        &mlir::runtime::mlirAsyncRuntimeCreateValue);
111   bind("mlirAsyncRuntimeEmplaceToken",
112        &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
113   bind("mlirAsyncRuntimeSetTokenError",
114        &mlir::runtime::mlirAsyncRuntimeSetTokenError);
115   bind("mlirAsyncRuntimeIsTokenError",
116        &mlir::runtime::mlirAsyncRuntimeIsTokenError);
117   bind("mlirAsyncRuntimeEmplaceValue",
118        &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
119   bind("mlirAsyncRuntimeSetValueError",
120        &mlir::runtime::mlirAsyncRuntimeSetValueError);
121   bind("mlirAsyncRuntimeIsValueError",
122        &mlir::runtime::mlirAsyncRuntimeIsValueError);
123   bind("mlirAsyncRuntimeIsGroupError",
124        &mlir::runtime::mlirAsyncRuntimeIsGroupError);
125   bind("mlirAsyncRuntimeAwaitToken",
126        &mlir::runtime::mlirAsyncRuntimeAwaitToken);
127   bind("mlirAsyncRuntimeAwaitValue",
128        &mlir::runtime::mlirAsyncRuntimeAwaitValue);
129   bind("mlirAsyncRuntimeAwaitTokenAndExecute",
130        &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
131   bind("mlirAsyncRuntimeAwaitValueAndExecute",
132        &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
133   bind("mlirAsyncRuntimeCreateGroup",
134        &mlir::runtime::mlirAsyncRuntimeCreateGroup);
135   bind("mlirAsyncRuntimeAddTokenToGroup",
136        &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
137   bind("mlirAsyncRuntimeAwaitAllInGroup",
138        &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
139   bind("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
140        &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
141   bind("mlirAsyncRuntimePrintCurrentThreadId",
142        &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
143 
144   return symbol_map;
145 }
146 
147 namespace {
148 
RuntimeAlignedAlloc(size_t alignment,size_t size)149 void *RuntimeAlignedAlloc(size_t alignment, size_t size) {
150   return AlignedAlloc(alignment, size);
151 }
152 
RuntimeMalloc(size_t size)153 void *RuntimeMalloc(size_t size) {
154   // AlignedAlloc() requires results to be deallocated with AlignedFree().
155   // Make all allocations aligned because there is only one RuntimeFree().
156   // Align to the size of a pointer by default.
157   return RuntimeAlignedAlloc(sizeof(void *), size);
158 }
159 
RuntimeFree(void * ptr)160 void RuntimeFree(void *ptr) { return AlignedFree(ptr); }
161 
162 }  // namespace
163 
AsyncRuntimeMemoryAllocationSymbolMap(llvm::orc::MangleAndInterner mangle)164 llvm::orc::SymbolMap AsyncRuntimeMemoryAllocationSymbolMap(
165     llvm::orc::MangleAndInterner mangle) {
166   llvm::orc::SymbolMap symbol_map;
167 
168   auto bind = [&](llvm::StringRef name, auto symbol_ptr) {
169     symbol_map[mangle(name)] = llvm::JITEvaluatedSymbol(
170         llvm::pointerToJITTargetAddress(symbol_ptr), llvm::JITSymbolFlags());
171   };
172 
173   bind("malloc", &RuntimeMalloc);
174   bind("free", &RuntimeFree);
175   bind("aligned_alloc", &RuntimeAlignedAlloc);
176 
177   return symbol_map;
178 }
179 
180 }  // namespace runtime
181 }  // namespace xla
182 
183 //===----------------------------------------------------------------------===//
184 // MLIR Async runtime API.
185 //===----------------------------------------------------------------------===//
186 
// TODO(b/192775419): All arguments passed from the JIT compiled code to the
// runtime API (pointers and scalar values alike) must be marked initialized
// when running with msan enabled. Currently we have no way to enable the
// sanitizer in the compiled code, so msan has no visibility into that code at
// runtime.
191 
192 namespace mlir {
193 namespace runtime {
194 
195 using xla::runtime::AsyncRuntime;
196 using xla::runtime::AsyncRuntimeObject;
197 
198 // Adds references to reference counted runtime object.
mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr,int64_t count)199 void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
200   AsyncRuntimeObject *obj = static_cast<AsyncRuntimeObject *>(ptr);
201   TFRT_MSAN_MEMORY_IS_INITIALIZED(&ptr, sizeof(RefCountedObjPtr));
202   TFRT_MSAN_MEMORY_IS_INITIALIZED(&count, sizeof(int64_t));
203   AsyncRuntime::AddRef(obj, count);
204 }
205 
206 // Drops references from reference counted runtime object.
mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr,int64_t count)207 void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
208   AsyncRuntimeObject *obj = static_cast<AsyncRuntimeObject *>(ptr);
209   TFRT_MSAN_MEMORY_IS_INITIALIZED(&ptr, sizeof(RefCountedObjPtr));
210   TFRT_MSAN_MEMORY_IS_INITIALIZED(&count, sizeof(int64_t));
211   AsyncRuntime::DropRef(obj, count);
212 }
213 
214 // Create a new `async.token` in not-ready state.
mlirAsyncRuntimeCreateToken()215 AsyncToken *mlirAsyncRuntimeCreateToken() {
216   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
217   return runtime.CreateToken();
218 }
219 
220 // Creates a new `async.value` in not-ready state.
mlirAsyncRuntimeCreateValue(int64_t size)221 AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
222   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
223   TFRT_MSAN_MEMORY_IS_INITIALIZED(&size, sizeof(int64_t));
224   return runtime.CreateValue(size, /*alignment=*/alignof(std::max_align_t));
225 }
226 
227 // Create a new `async.group` in empty state.
mlirAsyncRuntimeCreateGroup(int64_t size)228 AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
229   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
230   TFRT_MSAN_MEMORY_IS_INITIALIZED(&size, sizeof(int64_t));
231   return runtime.CreateGroup(size);
232 }
233 
mlirAsyncRuntimeAddTokenToGroup(AsyncToken * token,AsyncGroup * group)234 int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
235   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
236   TFRT_MSAN_MEMORY_IS_INITIALIZED(&token, sizeof(void *));
237   TFRT_MSAN_MEMORY_IS_INITIALIZED(&group, sizeof(void *));
238   return runtime.AddTokenToGroup(group, token);
239 }
240 
mlirAsyncRuntimeIsGroupError(AsyncGroup * group)241 bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
242   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
243   TFRT_MSAN_MEMORY_IS_INITIALIZED(&group, sizeof(void *));
244   return runtime.IsError(group);
245 }
246 
mlirAsyncRuntimeEmplaceToken(AsyncToken * token)247 void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
248   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
249   TFRT_MSAN_MEMORY_IS_INITIALIZED(&token, sizeof(void *));
250   runtime.SetAvailable(token);
251 }
252 
mlirAsyncRuntimeSetTokenError(AsyncToken * token)253 void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
254   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
255   TFRT_MSAN_MEMORY_IS_INITIALIZED(&token, sizeof(void *));
256   runtime.SetError(token);
257 }
258 
mlirAsyncRuntimeIsTokenError(AsyncToken * token)259 bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
260   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
261   TFRT_MSAN_MEMORY_IS_INITIALIZED(&token, sizeof(void *));
262   return runtime.IsError(token);
263 }
264 
mlirAsyncRuntimeAwaitToken(AsyncToken * token)265 void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
266   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
267   TFRT_MSAN_MEMORY_IS_INITIALIZED(&token, sizeof(void *));
268   runtime.AwaitToken(token);
269 }
270 
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup * group)271 void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
272   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
273   TFRT_MSAN_MEMORY_IS_INITIALIZED(&group, sizeof(void *));
274   runtime.AwaitGroup(group);
275 }
276 
mlirAsyncRuntimeGetValueStorage(AsyncValue * value)277 ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
278   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
279   TFRT_MSAN_MEMORY_IS_INITIALIZED(&value, sizeof(void *));
280   return runtime.GetStorage(value);
281 }
282 
mlirAsyncRuntimeEmplaceValue(AsyncValue * value)283 void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
284   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
285   TFRT_MSAN_MEMORY_IS_INITIALIZED(&value, sizeof(void *));
286   runtime.SetAvailable(value);
287 }
288 
mlirAsyncRuntimeSetValueError(AsyncValue * value)289 void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
290   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
291   TFRT_MSAN_MEMORY_IS_INITIALIZED(&value, sizeof(void *));
292   runtime.SetError(value);
293 }
294 
mlirAsyncRuntimeIsValueError(AsyncValue * value)295 bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
296   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
297   TFRT_MSAN_MEMORY_IS_INITIALIZED(&value, sizeof(void *));
298   return runtime.IsError(value);
299 }
300 
mlirAsyncRuntimeAwaitValue(AsyncValue * value)301 void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
302   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
303   TFRT_MSAN_MEMORY_IS_INITIALIZED(&value, sizeof(void *));
304   runtime.AwaitValue(value);
305 }
306 
mlirAsyncRuntimeExecute(CoroHandle handle,CoroResume resume)307 void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
308   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
309   runtime.Execute([resume, handle, runtime]() {
310     AsyncRuntime::Set(runtime);
311     (*resume)(handle);
312   });
313 }
314 
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken * token,CoroHandle handle,CoroResume resume)315 void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, CoroHandle handle,
316                                           CoroResume resume) {
317   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
318   TFRT_MSAN_MEMORY_IS_INITIALIZED(&token, sizeof(void *));
319   runtime.AwaitToken(token, [handle, resume, runtime]() {
320     AsyncRuntime::Set(runtime);
321     (*resume)(handle);
322   });
323 }
324 
mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue * value,CoroHandle handle,CoroResume resume)325 void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value, CoroHandle handle,
326                                           CoroResume resume) {
327   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
328   TFRT_MSAN_MEMORY_IS_INITIALIZED(&value, sizeof(void *));
329   runtime.AwaitValue(value, [handle, resume, runtime]() {
330     AsyncRuntime::Set(runtime);
331 
332     (*resume)(handle);
333   });
334 }
335 
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup * group,CoroHandle handle,CoroResume resume)336 void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
337                                                CoroHandle handle,
338                                                CoroResume resume) {
339   AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
340   TFRT_MSAN_MEMORY_IS_INITIALIZED(&group, sizeof(void *));
341   runtime.AwaitGroup(group, [handle, resume, runtime]() {
342     AsyncRuntime::Set(runtime);
343     (*resume)(handle);
344   });
345 }
346 
347 //===----------------------------------------------------------------------===//
348 // Small async runtime support library for testing.
349 //===----------------------------------------------------------------------===//
350 
// Test helper: prints the calling thread's id to stdout, flushing the stream.
// The id is captured once per thread, on that thread's first call.
void mlirAsyncRuntimePrintCurrentThreadId() {
  static thread_local const std::thread::id cached_id =
      std::this_thread::get_id();
  std::cout << "Current thread id: " << cached_id << std::endl;
}
355 
356 }  // namespace runtime
357 }  // namespace mlir
358