xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
18 
#include <memory>
#include <string>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/tf2xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/xla_argument.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
35 
36 namespace tensorflow {
37 
38 // Populates the supplied passmanager with the passes required to run the
39 // TF MLIR to XLA HLO MLIR conversion/legalization. Custom legalization passes
40 // can be populated in `custom_legalization_passes`.
41 void CreateConvertMlirToXlaHloPipeline(
42     mlir::OpPassManager& pm, llvm::StringRef device_type, bool prefer_tf2xla,
43     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
44         custom_legalization_passes,
45     bool allow_partial_conversion = false);
46 
47 // Lowers MLIR module to XLA HLO inside an XlaComputation. The input module
48 // should only contain operations in tf dialect. If the input module contains
49 // operation in the tf_executor dialect, for example, returns an error.
50 // Exception to this are tf_executor dialect ops that are optimized away through
51 // canonicalization.
52 //
53 // Operations in tf dialect are lowered to XLA HLO through the following steps:
54 //   . Legalizes control flow operations.
55 //   . Decomposes compound resource operations so that the only remaining
56 //     operations on resource variables are resource reads/writes..
57 //   . Replaces resource reads/writes with function inputs/outputs and
58 //     eliminates the use of resource variables.
59 //   . Legalizes the operations to XLA HLO operations.
60 //   . Canonicalizes the XLA HLO operations.
61 //
62 // device_type: XLA JIT device to use for compilation such as "XLA_CPU_JIT",
63 //   "XLA_GPU_JIT" or "XLA_TPU_JIT".
64 // use_tuple_args: when this is true, always create a tuple argument for the
65 //   entry computation.
66 // prefer_tf2xla: when this is true, prefer tf2xla fallback kernels over MLIR
67 //   native kernels for legalization to HLO.
68 // return_tuple: when this is true, always create a tuple result for the
69 //   entry computation.
70 // shape_determination_fns: Contains layout preference fn and shape
71 //   representation fn. The two functions are used to determine argument and
72 //   result shapes.
73 // custom_legalization_passes: passes to run before the default TF legalization
74 //   passes for backend-specific ops.
75 //
76 // TODO(hinsu): Migrate options to a separate struct.
77 Status ConvertMLIRToXlaComputation(
78     mlir::ModuleOp module_op, llvm::StringRef device_type,
79     xla::XlaComputation* xla_computation, bool use_tuple_args,
80     bool prefer_tf2xla, bool return_tuple,
81     const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns =
82         {},
83     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
84         custom_legalization_passes = {});
85 
86 // Helper struct representing argument tensor or resource handle shapes.
87 struct TensorOrResourceShape {
88   TensorShape shape;
89   bool is_resource = false;
90 };
91 
92 // Refine MLIR types based on new shape information.
93 Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
94                     mlir::ModuleOp module);
95 
96 // Lower TF to MHLO and insert HLO into the XlaBuilder. xla_params are HLO-level
97 // inputs to module_op that have already been added to the XlaBuilder. returns
98 // are the returned XlaOps.
99 Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
100                       llvm::ArrayRef<xla::XlaOp> xla_params,
101                       std::vector<xla::XlaOp>& returns,
102                       llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
103                       llvm::StringRef device_type,
104                       llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
105                           custom_legalization_passes);
106 
107 // Apply shape, description, and resource information to inputs and outputs
108 // in the XlaCompilationResult. This should be called after
109 // compilation_result->computation was set.
110 Status PopulateResultIOInfo(
111     mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
112     bool use_tuple_args, bool use_resource_updates_for_aliases,
113     const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
114     XlaCompilationResult* compilation_result);
115 
116 // Compiles a MLIR module into XLA HLO, generates all accompanying metadata and
117 // stores them in CompilationResult.
118 //
119 // If analyse_graph is set to true, graph is legalized only if the graph
120 // analysis for the graph is successful. Otherwise, an error is returned.
121 Status CompileMlirToXlaHlo(
122     mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
123     llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
124     bool use_return_tuple, bool use_resource_updates_for_aliases,
125     XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
126     XlaCompilationResult* compilation_result,
127     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
128         custom_legalization_passes);
129 
130 // Compiles a serialized MLIR module into XLA HLO, generates all accompanying
131 // metadata and stores them in CompilationResult.
132 Status CompileSerializedMlirToXlaHlo(
133     llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
134     llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
135     const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
136     XlaCompilationResult* compilation_result,
137     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
138         custom_legalization_passes = {});
139 
140 // Compiles a TensorFlow Graph (already converted to MLIR, imported with
141 // tf_executor dialect still present) into XLA HLO, generates all accompanying
142 // metadata and stores them in CompilationResult. This will rewrite arguments
143 // and run the TensorFlow standard pipeline prior to invoking
144 // `CompileMlirToXlaHlo`.
145 Status CompileGraphToXlaHlo(
146     mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
147     llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
148     bool use_return_tuple,
149     const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
150     XlaCompilationResult* compilation_result,
151     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
152         custom_legalization_passes);
153 
154 // Compiles a TensorFlow Graph into XLA HLO, generates all accompanying metadata
155 // and stores them in CompilationResult.
156 Status CompileGraphToXlaHlo(
157     const Graph& graph, llvm::ArrayRef<XlaArgument> args,
158     llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
159     bool use_tuple_args, bool analyse_graph,
160     const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
161     const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
162     XlaCompilationResult* compilation_result,
163     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
164         custom_legalization_passes = {});
165 
166 // Compiles a Graph from TF to HLO and adds the resulting HLO to the
167 // XlaBuilder. This function adds HLO to a larger HLO computation, so
168 // HLO-level inputs are supplied, and HLO-level outputs are produced.
169 // xla_params is the HLO-level inputs and returns is the HLO-level outputs.
170 Status BuildHloFromGraph(
171     const Graph& graph, xla::XlaBuilder& builder,
172     mlir::MLIRContext& mlir_context, llvm::ArrayRef<xla::XlaOp> xla_params,
173     std::vector<xla::XlaOp>& returns, llvm::ArrayRef<XlaArgument> args,
174     llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
175     const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
176     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
177         custom_legalization_passes = {});
178 
CompileToHloGraphAnalysisFailedError()179 static inline Status CompileToHloGraphAnalysisFailedError() {
180   return errors::Internal("disabled after graph analysis");
181 }
182 
// Registers, with default settings, a convenience pipeline so that TF-to-XLA
// lowering can be invoked from the command line.
void RegisterConvertMlirToXlaHloPipelineWithDefaults();
186 
187 }  // namespace tensorflow
188 
189 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_COMPILE_MLIR_UTIL_H_
190