1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef XLA_MLIR_UTILS_RUNTIME_ASYNC_RUNTIME_API_H_ 17 #define XLA_MLIR_UTILS_RUNTIME_ASYNC_RUNTIME_API_H_ 18 19 #include "llvm/ExecutionEngine/Orc/Core.h" 20 #include "llvm/ExecutionEngine/Orc/Mangling.h" 21 #include "tensorflow/compiler/xla/runtime/async_runtime.h" 22 #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime 23 24 namespace tfrt { 25 class Chain; 26 } // namespace tfrt 27 28 namespace xla { 29 namespace runtime { 30 31 // Converts MLIR Async Runtime token into the TFRT async chain, and drops the 32 // reference count on the token. 33 tfrt::AsyncValueRef<tfrt::Chain> ConvertAsyncTokenToChain( 34 AsyncRuntime::Token* token); 35 36 // Extracts a payload from the MLIR Async Runtime `value` and emplaces it into 37 // the TFRT async value `dst` using a user provided emplace function. Drops the 38 // reference on the runtime value after it is no longer needed. 39 void ExtractAsyncValue( 40 AsyncRuntime::Value* value, tfrt::AsyncValue* dst, 41 llvm::function_ref<void(void*, tfrt::AsyncValue*)> emplace_fn); 42 43 // A version of the `ExtractAsyncValue` function defined above that takes an 44 // additional opaque pointer that will be passed to the emplace function when 45 // async value will become ready. It is the caller responsibility to ensure that 46 // the pointed object will stay alive. 47 void ExtractAsyncValue( 48 AsyncRuntime::Value* value, tfrt::AsyncValue* dst, void* context, 49 llvm::function_ref<void(void*, tfrt::AsyncValue*, void*)> emplace_fn); 50 51 // Builds a symbol map from the Async Runtime API functions. 52 llvm::orc::SymbolMap AsyncRuntimeApiSymbolMap( 53 llvm::orc::MangleAndInterner mangle); 54 55 // TODO(ezhulenev): This should not be a part of async runtime api library. 56 llvm::orc::SymbolMap AsyncRuntimeMemoryAllocationSymbolMap( 57 llvm::orc::MangleAndInterner mangle); 58 59 } // namespace runtime 60 } // namespace xla 61 62 #endif // XLA_MLIR_UTILS_RUNTIME_ASYNC_RUNTIME_API_H_ 63