xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/mlir/transforms/runtime/calling_convention.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_MLIR_TRANSFORMS_RUNTIME_CALLING_CONVENTION_H_
17 #define XLA_MLIR_TRANSFORMS_RUNTIME_CALLING_CONVENTION_H_
18 
19 #include <functional>
20 
21 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
22 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
23 
24 namespace xla {
25 namespace runtime {
26 
27 // Calling convention converts the XLA entrypoint function type to the function
28 // type with a well defined ABI (e.g. tensors do not have an ABI, and must be
29 // passed across the function boundary as memrefs). In a nutshell it tells the
30 // XLA runtime how to call the compiled executable at run time, and how to
31 // return results back to the caller.
32 //
33 // All types in the converted function signature should have a registered
34 // type conversion (see `type_converter` below) to a type with defined
35 // argument or result ABI (see Type::ArgumentAbi and Type::ResultAbi).
36 //
37 // If conversion is not possible, calling convention must return a null value.
38 //
39 // Example: abstract kernel defined in high level dialect, e.g. MHLO
40 //
41 //   ```mlir
42 //     func @kernel(%arg0: tensor<?xf32>,
43 //                  %arg1: tensor<?xf32>) -> tensor<?x?xf32> { ... }
44 //   ```
45 //
46 //   after calling convention conversion becomes:
47 //
48 //   ```mlir
49 //     func @kernel(%ctx: !rt.kernel_context,
50 //                  %arg0: memref<?xf32>,
51 //                  %arg1: memref<?xf32>) -> memref<?x?xf32> { ... }
52 //   ```
53 //
54 // Calling convention function type is not the same as the entrypoint function
55 // type produced by the compilation pipeline for several reasons:
56 //
57 // 1) Compilation pipeline produces LLVM functions with LLVM types, and high
58 //    level information is lost, e.g. all memrefs are deconstructed into
59 //    primitive fields when passed as inputs.
60 //
61 // 2) Compiled kernel function always returns void, and uses runtime API to
62 //    return results back to the caller (see `rt-convert-to-entrypoint` pass).
63 //
64 // Calling convention function type is an XLA-compatible description of the
65 // compiled kernel ABI, so that XLA runtime can correctly initialize CallFrame
66 // arguments, allocate memory for returned results, and then correctly decode
67 // results memory into the high level types (e.g. convert returned memref
68 // descriptor to a Tensor).
69 class CallingConvention
70     : public std::function<mlir::FunctionType(mlir::FunctionType)> {
71   using function::function;
72 };
73 
74 // Returns a calling convention that only adds the kernel context argument.
75 CallingConvention DefaultCallingConvention();
76 
77 // Returns a calling convention that uses user-provided type converter to
78 // convert all inputs and results types, and adds the kernel context argument.
79 CallingConvention DefaultCallingConvention(mlir::TypeConverter);
80 
81 // Returns a calling convention that (1) prepends the kernel context argument,
82 // (2) uses the user-provided type converter to convert all inputs and results
83 // types, and (3) converts result types into out-params by appending them
84 // to the arguments.
85 CallingConvention ResultsToOutsCallingConvention(mlir::TypeConverter);
86 
87 }  // namespace runtime
88 }  // namespace xla
89 
90 #endif  // XLA_MLIR_TRANSFORMS_RUNTIME_CALLING_CONVENTION_H_
91