1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/mlir/utils/runtime/async_runtime_api.h"
17
18 #include <stdlib.h>
19
20 #include <cstddef>
21 #include <iostream>
22 #include <ostream>
23 #include <thread> // NOLINT TODO(ezhulenev): Remove this header.
24 #include <type_traits>
25
26 #include "mlir/ExecutionEngine/AsyncRuntime.h" // from @llvm-project
27 #include "tensorflow/compiler/xla/runtime/async_runtime.h"
28 #include "tfrt/host_context/async_value.h" // from @tf_runtime
29 #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime
30 #include "tfrt/host_context/chain.h" // from @tf_runtime
31 #include "tfrt/support/alloc.h" // from @tf_runtime
32 #include "tfrt/support/msan.h" // from @tf_runtime
33
34 namespace xla {
35 namespace runtime {
36
37 using tfrt::AlignedAlloc;
38 using tfrt::AlignedFree;
39 using tfrt::AsyncValue;
40 using tfrt::AsyncValueRef;
41 using tfrt::Chain;
42
ConvertAsyncTokenToChain(AsyncRuntime::Token * token)43 AsyncValueRef<Chain> ConvertAsyncTokenToChain(AsyncRuntime::Token *token) {
44 auto *async_value = AsyncRuntime::GetAsyncValue(token);
45 auto out_chain = AsyncValueRef<Chain>(FormRef(async_value));
46 AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(token));
47 return out_chain;
48 }
49
ExtractAsyncValue(AsyncRuntime::Value * value,AsyncValue * dst,llvm::function_ref<void (void * storage,AsyncValue * dst)> emplace_fn)50 void ExtractAsyncValue(
51 AsyncRuntime::Value *value, AsyncValue *dst,
52 llvm::function_ref<void(void *storage, AsyncValue *dst)> emplace_fn) {
53 auto *async_value = AsyncRuntime::GetAsyncValue(value);
54
55 // Fast path if async value is already available.
56 if (async_value->IsAvailable()) {
57 void *storage = AsyncRuntime::GetStorage(value);
58 emplace_fn(storage, dst);
59 AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(value));
60 return;
61 }
62
63 // Wait for the async value completion, and emplace the `dst`.
64 async_value->AndThen([value, emplace_fn, dst = FormRef(dst)]() {
65 void *storage = AsyncRuntime::GetStorage(value);
66 emplace_fn(storage, dst.get());
67 AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(value));
68 });
69 }
70
ExtractAsyncValue(AsyncRuntime::Value * value,AsyncValue * dst,void * context,llvm::function_ref<void (void * storage,AsyncValue * dst,void * context)> emplace_fn)71 void ExtractAsyncValue(
72 AsyncRuntime::Value *value, AsyncValue *dst, void *context,
73 llvm::function_ref<void(void *storage, AsyncValue *dst, void *context)>
74 emplace_fn) {
75 auto *async_value = AsyncRuntime::GetAsyncValue(value);
76
77 // Fast path if async value is already available.
78 if (async_value->IsAvailable()) {
79 void *storage = AsyncRuntime::GetStorage(value);
80 emplace_fn(storage, dst, context);
81 AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(value));
82 return;
83 }
84
85 // Wait for the async value completion, and emplace the `dst`.
86 async_value->AndThen([value, emplace_fn, context, dst = FormRef(dst)]() {
87 void *storage = AsyncRuntime::GetStorage(value);
88 emplace_fn(storage, dst.get(), context);
89 AsyncRuntime::DropRef(AsyncRuntime::ToAsyncRuntimeObject(value));
90 });
91 }
92
AsyncRuntimeApiSymbolMap(llvm::orc::MangleAndInterner mangle)93 llvm::orc::SymbolMap AsyncRuntimeApiSymbolMap(
94 llvm::orc::MangleAndInterner mangle) {
95 llvm::orc::SymbolMap symbol_map;
96
97 auto bind = [&](llvm::StringRef name, auto symbol_ptr) {
98 symbol_map[mangle(name)] = llvm::JITEvaluatedSymbol(
99 llvm::pointerToJITTargetAddress(symbol_ptr), llvm::JITSymbolFlags());
100 };
101
102 bind("mlirAsyncRuntimeAddRef", &mlir::runtime::mlirAsyncRuntimeAddRef);
103 bind("mlirAsyncRuntimeDropRef", &mlir::runtime::mlirAsyncRuntimeDropRef);
104 bind("mlirAsyncRuntimeExecute", &mlir::runtime::mlirAsyncRuntimeExecute);
105 bind("mlirAsyncRuntimeGetValueStorage",
106 &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
107 bind("mlirAsyncRuntimeCreateToken",
108 &mlir::runtime::mlirAsyncRuntimeCreateToken);
109 bind("mlirAsyncRuntimeCreateValue",
110 &mlir::runtime::mlirAsyncRuntimeCreateValue);
111 bind("mlirAsyncRuntimeEmplaceToken",
112 &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
113 bind("mlirAsyncRuntimeSetTokenError",
114 &mlir::runtime::mlirAsyncRuntimeSetTokenError);
115 bind("mlirAsyncRuntimeIsTokenError",
116 &mlir::runtime::mlirAsyncRuntimeIsTokenError);
117 bind("mlirAsyncRuntimeEmplaceValue",
118 &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
119 bind("mlirAsyncRuntimeSetValueError",
120 &mlir::runtime::mlirAsyncRuntimeSetValueError);
121 bind("mlirAsyncRuntimeIsValueError",
122 &mlir::runtime::mlirAsyncRuntimeIsValueError);
123 bind("mlirAsyncRuntimeIsGroupError",
124 &mlir::runtime::mlirAsyncRuntimeIsGroupError);
125 bind("mlirAsyncRuntimeAwaitToken",
126 &mlir::runtime::mlirAsyncRuntimeAwaitToken);
127 bind("mlirAsyncRuntimeAwaitValue",
128 &mlir::runtime::mlirAsyncRuntimeAwaitValue);
129 bind("mlirAsyncRuntimeAwaitTokenAndExecute",
130 &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
131 bind("mlirAsyncRuntimeAwaitValueAndExecute",
132 &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
133 bind("mlirAsyncRuntimeCreateGroup",
134 &mlir::runtime::mlirAsyncRuntimeCreateGroup);
135 bind("mlirAsyncRuntimeAddTokenToGroup",
136 &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
137 bind("mlirAsyncRuntimeAwaitAllInGroup",
138 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
139 bind("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
140 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
141 bind("mlirAsyncRuntimePrintCurrentThreadId",
142 &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
143
144 return symbol_map;
145 }
146
147 namespace {
148
RuntimeAlignedAlloc(size_t alignment,size_t size)149 void *RuntimeAlignedAlloc(size_t alignment, size_t size) {
150 return AlignedAlloc(alignment, size);
151 }
152
RuntimeMalloc(size_t size)153 void *RuntimeMalloc(size_t size) {
154 // AlignedAlloc() requires results to be deallocated with AlignedFree().
155 // Make all allocations aligned because there is only one RuntimeFree().
156 // Align to the size of a pointer by default.
157 return RuntimeAlignedAlloc(sizeof(void *), size);
158 }
159
RuntimeFree(void * ptr)160 void RuntimeFree(void *ptr) { return AlignedFree(ptr); }
161
162 } // namespace
163
AsyncRuntimeMemoryAllocationSymbolMap(llvm::orc::MangleAndInterner mangle)164 llvm::orc::SymbolMap AsyncRuntimeMemoryAllocationSymbolMap(
165 llvm::orc::MangleAndInterner mangle) {
166 llvm::orc::SymbolMap symbol_map;
167
168 auto bind = [&](llvm::StringRef name, auto symbol_ptr) {
169 symbol_map[mangle(name)] = llvm::JITEvaluatedSymbol(
170 llvm::pointerToJITTargetAddress(symbol_ptr), llvm::JITSymbolFlags());
171 };
172
173 bind("malloc", &RuntimeMalloc);
174 bind("free", &RuntimeFree);
175 bind("aligned_alloc", &RuntimeAlignedAlloc);
176
177 return symbol_map;
178 }
179
180 } // namespace runtime
181 } // namespace xla
182
183 //===----------------------------------------------------------------------===//
184 // MLIR Async runtime API.
185 //===----------------------------------------------------------------------===//
186
// TODO(b/192775419): All arguments passed from the JIT compiled code to the
// runtime API (pointers as well as scalars such as sizes and reference counts)
// must be marked initialized when running with msan enabled, because currently
// we do not have a way to enable the sanitizer in the compiled code, and msan
// has no visibility into that code at runtime.
191
192 namespace mlir {
193 namespace runtime {
194
195 using xla::runtime::AsyncRuntime;
196 using xla::runtime::AsyncRuntimeObject;
197
198 // Adds references to reference counted runtime object.
mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr,int64_t count)199 void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
200 AsyncRuntimeObject *obj = static_cast<AsyncRuntimeObject *>(ptr);
201 TFRT_MSAN_MEMORY_IS_INITIALIZED(&ptr, sizeof(RefCountedObjPtr));
202 TFRT_MSAN_MEMORY_IS_INITIALIZED(&count, sizeof(int64_t));
203 AsyncRuntime::AddRef(obj, count);
204 }
205
206 // Drops references from reference counted runtime object.
mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr,int64_t count)207 void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
208 AsyncRuntimeObject *obj = static_cast<AsyncRuntimeObject *>(ptr);
209 TFRT_MSAN_MEMORY_IS_INITIALIZED(&ptr, sizeof(RefCountedObjPtr));
210 TFRT_MSAN_MEMORY_IS_INITIALIZED(&count, sizeof(int64_t));
211 AsyncRuntime::DropRef(obj, count);
212 }
213
214 // Create a new `async.token` in not-ready state.
mlirAsyncRuntimeCreateToken()215 AsyncToken *mlirAsyncRuntimeCreateToken() {
216 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
217 return runtime.CreateToken();
218 }
219
220 // Creates a new `async.value` in not-ready state.
mlirAsyncRuntimeCreateValue(int64_t size)221 AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
222 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
223 TFRT_MSAN_MEMORY_IS_INITIALIZED(&size, sizeof(int64_t));
224 return runtime.CreateValue(size, /*alignment=*/alignof(std::max_align_t));
225 }
226
227 // Create a new `async.group` in empty state.
mlirAsyncRuntimeCreateGroup(int64_t size)228 AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
229 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
230 TFRT_MSAN_MEMORY_IS_INITIALIZED(&size, sizeof(int64_t));
231 return runtime.CreateGroup(size);
232 }
233
mlirAsyncRuntimeAddTokenToGroup(AsyncToken * token,AsyncGroup * group)234 int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token, AsyncGroup *group) {
235 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
236 TFRT_MSAN_MEMORY_IS_INITIALIZED(&token, sizeof(void *));
237 TFRT_MSAN_MEMORY_IS_INITIALIZED(&group, sizeof(void *));
238 return runtime.AddTokenToGroup(group, token);
239 }
240
mlirAsyncRuntimeIsGroupError(AsyncGroup * group)241 bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
242 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
243 TFRT_MSAN_MEMORY_IS_INITIALIZED(&group, sizeof(void *));
244 return runtime.IsError(group);
245 }
246
mlirAsyncRuntimeEmplaceToken(AsyncToken * token)247 void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
248 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
249 TFRT_MSAN_MEMORY_IS_INITIALIZED(&token, sizeof(void *));
250 runtime.SetAvailable(token);
251 }
252
mlirAsyncRuntimeSetTokenError(AsyncToken * token)253 void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
254 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
255 TFRT_MSAN_MEMORY_IS_INITIALIZED(&token, sizeof(void *));
256 runtime.SetError(token);
257 }
258
mlirAsyncRuntimeIsTokenError(AsyncToken * token)259 bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
260 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
261 TFRT_MSAN_MEMORY_IS_INITIALIZED(&token, sizeof(void *));
262 return runtime.IsError(token);
263 }
264
mlirAsyncRuntimeAwaitToken(AsyncToken * token)265 void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
266 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
267 TFRT_MSAN_MEMORY_IS_INITIALIZED(&token, sizeof(void *));
268 runtime.AwaitToken(token);
269 }
270
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup * group)271 void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
272 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
273 TFRT_MSAN_MEMORY_IS_INITIALIZED(&group, sizeof(void *));
274 runtime.AwaitGroup(group);
275 }
276
mlirAsyncRuntimeGetValueStorage(AsyncValue * value)277 ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
278 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
279 TFRT_MSAN_MEMORY_IS_INITIALIZED(&value, sizeof(void *));
280 return runtime.GetStorage(value);
281 }
282
mlirAsyncRuntimeEmplaceValue(AsyncValue * value)283 void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
284 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
285 TFRT_MSAN_MEMORY_IS_INITIALIZED(&value, sizeof(void *));
286 runtime.SetAvailable(value);
287 }
288
mlirAsyncRuntimeSetValueError(AsyncValue * value)289 void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
290 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
291 TFRT_MSAN_MEMORY_IS_INITIALIZED(&value, sizeof(void *));
292 runtime.SetError(value);
293 }
294
mlirAsyncRuntimeIsValueError(AsyncValue * value)295 bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
296 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
297 TFRT_MSAN_MEMORY_IS_INITIALIZED(&value, sizeof(void *));
298 return runtime.IsError(value);
299 }
300
mlirAsyncRuntimeAwaitValue(AsyncValue * value)301 void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
302 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
303 TFRT_MSAN_MEMORY_IS_INITIALIZED(&value, sizeof(void *));
304 runtime.AwaitValue(value);
305 }
306
mlirAsyncRuntimeExecute(CoroHandle handle,CoroResume resume)307 void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
308 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
309 runtime.Execute([resume, handle, runtime]() {
310 AsyncRuntime::Set(runtime);
311 (*resume)(handle);
312 });
313 }
314
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken * token,CoroHandle handle,CoroResume resume)315 void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token, CoroHandle handle,
316 CoroResume resume) {
317 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
318 TFRT_MSAN_MEMORY_IS_INITIALIZED(&token, sizeof(void *));
319 runtime.AwaitToken(token, [handle, resume, runtime]() {
320 AsyncRuntime::Set(runtime);
321 (*resume)(handle);
322 });
323 }
324
mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue * value,CoroHandle handle,CoroResume resume)325 void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value, CoroHandle handle,
326 CoroResume resume) {
327 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
328 TFRT_MSAN_MEMORY_IS_INITIALIZED(&value, sizeof(void *));
329 runtime.AwaitValue(value, [handle, resume, runtime]() {
330 AsyncRuntime::Set(runtime);
331
332 (*resume)(handle);
333 });
334 }
335
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup * group,CoroHandle handle,CoroResume resume)336 void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
337 CoroHandle handle,
338 CoroResume resume) {
339 AsyncRuntime &runtime = AsyncRuntime::GetCurrentRuntime();
340 TFRT_MSAN_MEMORY_IS_INITIALIZED(&group, sizeof(void *));
341 runtime.AwaitGroup(group, [handle, resume, runtime]() {
342 AsyncRuntime::Set(runtime);
343 (*resume)(handle);
344 });
345 }
346
347 //===----------------------------------------------------------------------===//
348 // Small async runtime support library for testing.
349 //===----------------------------------------------------------------------===//
350
// Prints the id of the calling thread to stdout and flushes the stream. Part
// of the small support library used for testing the async runtime.
void mlirAsyncRuntimePrintCurrentThreadId() {
  const std::thread::id current_thread = std::this_thread::get_id();
  std::cout << "Current thread id: " << current_thread << std::endl;
}
355
356 } // namespace runtime
357 } // namespace mlir
358