xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir/utils/runtime/async_runtime_api.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_MLIR_UTILS_RUNTIME_ASYNC_RUNTIME_API_H_
17 #define XLA_MLIR_UTILS_RUNTIME_ASYNC_RUNTIME_API_H_
18 
19 #include "llvm/ExecutionEngine/Orc/Core.h"
20 #include "llvm/ExecutionEngine/Orc/Mangling.h"
21 #include "tensorflow/compiler/xla/runtime/async_runtime.h"
22 #include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
23 
24 namespace tfrt {
// Forward declaration: in this header `Chain` is only used as the payload type
// of the `AsyncValueRef` returned by `ConvertAsyncTokenToChain`.
25 class Chain;
26 }  // namespace tfrt
27 
28 namespace xla {
29 namespace runtime {
30 
31 // Converts an MLIR Async Runtime token into a TFRT async chain, and drops the
32 // reference count on the token.
// NOTE(review): presumably `token` must be non-null and the returned chain
// becomes available once the token is ready; confirm against the definition.
33 tfrt::AsyncValueRef<tfrt::Chain> ConvertAsyncTokenToChain(
34     AsyncRuntime::Token* token);
35 
36 // Extracts a payload from the MLIR Async Runtime `value` and emplaces it into
37 // the TFRT async value `dst` using a user-provided emplace function. Drops the
38 // reference on the runtime value after it is no longer needed.
//
// `emplace_fn` is an `llvm::function_ref`, i.e. a non-owning reference to a
// callable. NOTE(review): if the emplace is deferred until `value` becomes
// ready, the referenced callable must outlive that deferred call. A free
// function is safe; a temporary lambda or function-pointer variable may
// dangle. Presumably the arguments are (payload storage of `value`, `dst`);
// confirm against the definition.
39 void ExtractAsyncValue(
40     AsyncRuntime::Value* value, tfrt::AsyncValue* dst,
41     llvm::function_ref<void(void*, tfrt::AsyncValue*)> emplace_fn);
42 
43 // A version of the `ExtractAsyncValue` function declared above that takes an
44 // additional opaque pointer `context`, which is passed to the emplace function
45 // when the async value becomes ready. It is the caller's responsibility to
46 // ensure that the pointed-to object stays alive until that call happens.
// `emplace_fn` is a non-owning `llvm::function_ref`. Because it is invoked
// when the value becomes ready, the referenced callable must also outlive that
// call (prefer passing a free function).
47 void ExtractAsyncValue(
48     AsyncRuntime::Value* value, tfrt::AsyncValue* dst, void* context,
49     llvm::function_ref<void(void*, tfrt::AsyncValue*, void*)> emplace_fn);
50 
51 // Builds a symbol map from the Async Runtime API functions.
// NOTE(review): `mangle` is presumably used to mangle and intern the symbol
// names so JIT-compiled code can resolve these functions; confirm with callers.
52 llvm::orc::SymbolMap AsyncRuntimeApiSymbolMap(
53     llvm::orc::MangleAndInterner mangle);
54 
55 // TODO(ezhulenev): This should not be a part of the async runtime API library.
// Builds a symbol map for memory allocation functions (presumably the ones
// that JIT-compiled async code calls; confirm against the definition).
56 llvm::orc::SymbolMap AsyncRuntimeMemoryAllocationSymbolMap(
57     llvm::orc::MangleAndInterner mangle);
58 
59 }  // namespace runtime
60 }  // namespace xla
61 
62 #endif  // XLA_MLIR_UTILS_RUNTIME_ASYNC_RUNTIME_API_H_
63