1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef XLA_MLIR_TRANSFORMS_RUNTIME_CALLING_CONVENTION_H_ 17 #define XLA_MLIR_TRANSFORMS_RUNTIME_CALLING_CONVENTION_H_ 18 19 #include <functional> 20 21 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project 22 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project 23 24 namespace xla { 25 namespace runtime { 26 27 // Calling convention converts the XLA entrypoint function type to the function 28 // type with a well defined ABI (e.g. tensors do not have an ABI, and must be 29 // passed across the function boundary as memrefs). In a nutshell it tells the 30 // XLA runtime how to call the compiled executable at run time, and how to 31 // return results back to the caller. 32 // 33 // All types in the converted function signature should have a registered 34 // type conversion (see `type_converter` below) to a type with defined 35 // argument or result ABI (see Type::ArgumentAbi and Type::ResultAbi). 36 // 37 // If conversion is not possible, calling convention must return a null value. 38 // 39 // Example: abstract kernel defined in high level dialect, e.g. MHLO 40 // 41 // ```mlir 42 // func @kernel(%arg0: tensor<?xf32>, 43 // %arg1: tensor<?xf32>) -> tensor<?x?xf32> { ... } 44 // ``` 45 // 46 // after calling convention conversion becomes: 47 // 48 // ```mlir 49 // func @kernel(%ctx: !rt.kernel_context, 50 // %arg0: memref<?xf32>, 51 // %arg1: memref<?xf32>) -> memref<?x?xf32> { ... } 52 // ``` 53 // 54 // Calling convention function type is not the same as the entrypoint function 55 // type produced by the compilation pipeline for several reasons: 56 // 57 // 1) Compilation pipeline produces LLVM functions with LLVM types, and high 58 // level information is lost, e.g. all memrefs are deconstructed into 59 // primitive fields when passed as inputs. 60 // 61 // 2) Compiled kernel function always returns void, and uses runtime API to 62 // return results back to the caller (see `rt-convert-to-entrypoint` pass). 63 // 64 // Calling convention function type is a XLA-compatible description of the 65 // compiled kernel ABI, so that XLA runtime can correctly initialize CallFrame 66 // arguments, allocate memory for returned results, and then correctly decode 67 // results memory into the high level types (e.g. convert returned memref 68 // descriptor to a Tensor). 69 class CallingConvention 70 : public std::function<mlir::FunctionType(mlir::FunctionType)> { 71 using function::function; 72 }; 73 74 // Returns a calling convention that only adds the kernel context argument. 75 CallingConvention DefaultCallingConvention(); 76 77 // Returns a calling convention that uses user-provided type converter to 78 // convert all inputs and results types, and adds the kernel context argument. 79 CallingConvention DefaultCallingConvention(mlir::TypeConverter); 80 81 // Returns a calling convention that (1) prepends the kernel context argument, 82 // (2) uses the user-provided type converter to convert all inputs and results 83 // types, and (3) converts result types into out-params by appending them 84 // to the arguments. 85 CallingConvention ResultsToOutsCallingConvention(mlir::TypeConverter); 86 87 } // namespace runtime 88 } // namespace xla 89 90 #endif // XLA_MLIR_TRANSFORMS_RUNTIME_CALLING_CONVENTION_H_ 91