/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"

#include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/mlir/xla/layout_util.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/transforms/adjust_layout.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/tf2xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/error_payloads.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/core_platform_payloads.pb.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {

constexpr absl::string_view kGroupSizeAttrName =
    "tf2xla.collective_info.group_size";
constexpr absl::string_view kGroupKeyAttrName =
    "tf2xla.collective_info.group_key";
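
// When present, these are expected as integer attributes on the module op,
// e.g. (illustrative only; the exact integer attribute type is not checked
// here):
//
//   module attributes {tf2xla.collective_info.group_key = 0 : i32,
//                      tf2xla.collective_info.group_size = 2 : i32} { ... }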

// Extracts the shape from an XlaArgument as a TensorShape. If the shape is an
// xla::Shape, it is converted to a TensorShape.
StatusOr<TensorShape> GetTensorShapeFromXlaArgument(const XlaArgument& arg) {
  if (absl::holds_alternative<xla::Shape>(arg.shape)) {
    TensorShape arg_shape;
    TF_RETURN_IF_ERROR(
        XLAShapeToTensorShape(std::get<xla::Shape>(arg.shape), &arg_shape));
    return arg_shape;
  } else {
    return std::get<TensorShape>(arg.shape);
  }
}

Status MaybeRewriteLayoutWithShardedShape(
    mlir::StringAttr sharding,
    const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    xla::Shape* shape) {
  if (!sharding) return OkStatus();

  xla::OpSharding op_sharding;
  if (!op_sharding.ParseFromString(sharding.getValue().str()))
    return errors::InvalidArgument("failed to parse sharding '",
                                   sharding.getValue().str(), "'");
  std::optional<xla::HloSharding> hlo_sharding;
  TF_ASSIGN_OR_RETURN(hlo_sharding, xla::HloSharding::FromProto(op_sharding));
  TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
      hlo_sharding, /*use_fast_memory=*/false, shape_determination_fns, shape));
  return OkStatus();
}

// Converts arg_shapes to xla::Shapes and stores them in xla_input_shapes.
Status GetXlaInputShapes(
    mlir::ModuleOp module, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    bool use_tuple_args,
    const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    std::vector<xla::Shape>* xla_input_shapes) {
  xla_input_shapes->clear();

  mlir::func::FuncOp main_func =
      module.lookupSymbol<mlir::func::FuncOp>("main");
  TF_RET_CHECK(main_func != nullptr) << "No main function found";
  mlir::FunctionType func_type = main_func.getFunctionType();

  int num_args = func_type.getNumInputs();
  xla_input_shapes->reserve(num_args);

  std::vector<xla::Shape> individual_arg_shapes;
  individual_arg_shapes.reserve(num_args);
  for (int i = 0; i < num_args; ++i) {
    individual_arg_shapes.emplace_back();
    xla::Shape& xla_shape = individual_arg_shapes.back();

    DataType arg_dtype;
    TF_RETURN_IF_ERROR(ConvertToDataType(func_type.getInput(i), &arg_dtype));

    auto layout_preference = shape_determination_fns.layout_preference_fn(
        arg_shapes[i].shape, arg_dtype, std::nullopt);
    TF_ASSIGN_OR_RETURN(xla_shape,
                        shape_determination_fns.shape_representation_fn(
                            arg_shapes[i].shape, arg_dtype,
                            /*use_fast_memory=*/false, layout_preference));

    // Rewrite layout with sharding, if sharding is set.
    auto sharding =
        main_func.getArgAttrOfType<mlir::StringAttr>(i, "mhlo.sharding");
    TF_RETURN_IF_ERROR(MaybeRewriteLayoutWithShardedShape(
        sharding, shape_determination_fns, &xla_shape));
  }
  if (use_tuple_args) {
    xla_input_shapes->push_back(
        xla::ShapeUtil::MakeTupleShape(individual_arg_shapes));
  } else {
    *xla_input_shapes = individual_arg_shapes;
  }
  return OkStatus();
}

// Returns a static ranked tensor type corresponding to the given static or
// bounded type by using the bounds as dimension sizes. Returns null if it is
// neither.
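//
// For example (illustrative only), a bounded dynamic type such as
//   tensor<?xi32, #mhlo.type_extensions<bounds = [8]>>
// maps to the static buffer type tensor<8xi32>.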
mlir::RankedTensorType GetBufferType(mlir::Type ty) {
  auto ranked_ty = ty.dyn_cast_or_null<mlir::RankedTensorType>();
  if (!ranked_ty) return {};

  int64_t rank = ranked_ty.getRank();
  llvm::SmallVector<int64_t, 4> dims = llvm::to_vector<4>(ranked_ty.getShape());
  auto encoding = ranked_ty.getEncoding()
                      .dyn_cast_or_null<mlir::mhlo::TypeExtensionsAttr>();
  if (encoding && !encoding.getBounds().empty()) {
    for (int64_t dim = 0; dim < rank; ++dim) {
      if (dims[dim] == mlir::ShapedType::kDynamicSize) {
        dims[dim] = encoding.getBounds()[dim];
      }
    }
  }
  return mlir::RankedTensorType::get(dims, ranked_ty.getElementType());
}

// Calculates the computation output shape and builds an OutputDescription for
// each output based on static shapes in the MLIR module. If an output is a
// resource write, `resource_updates` is populated instead of `outputs` for
// that output.
Status GetOutputInfo(
    mlir::ModuleOp module, bool use_resource_updates_for_aliases,
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    xla::Shape* xla_output_shape, std::vector<XlaOutputDescription>* outputs,
    std::vector<XlaResourceUpdate>* resource_updates) {
  auto shape_representation_fn_no_fast_memory =
      [shape_determination_fns](
          const xla::Shape& xla_shape) -> StatusOr<xla::Shape> {
    TensorShape shape;
    TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                            xla_shape.element_type()));
    auto layout_preference = shape_determination_fns.layout_preference_fn(
        shape, dtype, std::nullopt);
    return shape_determination_fns.shape_representation_fn(
        shape, dtype, /*use_fast_memory=*/false, layout_preference);
  };

  mlir::func::FuncOp main_func =
      module.lookupSymbol<mlir::func::FuncOp>("main");
  mlir::FunctionType func_type = main_func.getFunctionType();

  outputs->clear();
  outputs->reserve(func_type.getNumResults());
  resource_updates->clear();
  resource_updates->reserve(func_type.getNumResults());

  std::vector<xla::Shape> shapes;
  shapes.reserve(func_type.getNumResults());

  llvm::SmallDenseMap<unsigned, unsigned> output_to_input_alias;
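  // An alias is expressed as an integer argument attribute on `main`, e.g.
  // (illustrative only; the attribute's integer type may vary):
  //
  //   func.func @main(%arg0: tensor<2xf32> {tf.aliasing_output = 0 : i64})
  //       -> tensor<2xf32>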
  for (unsigned i = 0; i < main_func.getNumArguments(); ++i)
    if (auto aliasing_output = main_func.getArgAttrOfType<mlir::IntegerAttr>(
            i, "tf.aliasing_output"))
      output_to_input_alias[aliasing_output.getInt()] = i;

  auto return_op = main_func.begin()->getTerminator();
  for (const auto& type_and_idx : llvm::enumerate(func_type.getResults())) {
    size_t idx = type_and_idx.index();
    auto result_ty = type_and_idx.value().cast<mlir::RankedTensorType>();

    // If the result type isn't static, then the owner of the result may be a
    // cast op from a more specific bounded type to an unbounded dynamic type.
    // Use the bounded type to get the buffer size.
    mlir::RankedTensorType buffer_ty = result_ty;
    if (!buffer_ty.hasStaticShape()) {
      mlir::Value return_val = return_op->getOperand(idx);
      if (auto owner = mlir::dyn_cast_or_null<mlir::tensor::CastOp>(
              return_val.getDefiningOp())) {
        // For bounded dynamic type, get a static size by taking bounds as the
        // dimensions. These dimensions are marked as dynamic in xla::Shape
        // below.
        buffer_ty = GetBufferType(owner.getOperand().getType());
        if (!buffer_ty || !buffer_ty.hasStaticShape()) {
          return errors::InvalidArgument(
              "results needs to be static or bounded");
        }
      }
    }

    xla::Shape shape = xla::TypeToShape(buffer_ty);
    if (shape.element_type() == xla::PRIMITIVE_TYPE_INVALID) {
      return errors::InvalidArgument("XLA conversion failed for MLIR type.");
    }
    TF_ASSIGN_OR_RETURN(shape, shape_representation_fn_no_fast_memory(shape));

    if (!result_ty.hasStaticShape()) {
      int64_t rank = result_ty.getRank();
      for (int64_t dim = 0; dim < rank; ++dim) {
        if (result_ty.isDynamicDim(dim)) {
          shape.set_dynamic_dimension(dim, true);
        }
      }
    }

    auto sharding = main_func.getResultAttrOfType<mlir::StringAttr>(
        type_and_idx.index(), "mhlo.sharding");
    TF_RETURN_IF_ERROR(MaybeRewriteLayoutWithShardedShape(
        sharding, shape_determination_fns, &shape));

    auto tensor_type = type_and_idx.value().dyn_cast<mlir::RankedTensorType>();
    shapes.push_back(shape);

    auto it = output_to_input_alias.find(type_and_idx.index());
    if (it != output_to_input_alias.end() && use_resource_updates_for_aliases) {
      // Add resource write.
      resource_updates->emplace_back();
      XlaResourceUpdate& resource_update = resource_updates->back();
      resource_update.input_index = it->getSecond();
      resource_update.modified = true;
      TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &resource_update.type));
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &resource_update.shape));
      continue;
    }
    // Construct OutputDescription for result.
    outputs->emplace_back();
    XlaOutputDescription& out_desc = outputs->back();
    TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &out_desc.type));
    // TODO(ycao): Support constant output.
    out_desc.is_constant = false;
    TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &out_desc.shape));
    // input_index is only meaningful for resource outputs. Set it to the
    // sentinel value -1 for non-resource outputs.
    out_desc.input_index =
        it != output_to_input_alias.end() ? it->getSecond() : -1;
    // MLIR-based TF-Compiler bridge doesn't support tensorlist output yet.
    // TODO(ycao): Support tensorlist-type output.
    out_desc.is_tensor_list = false;
  }

  // XLA computation always uses Tuple shape.
  *xla_output_shape = xla::ShapeUtil::MakeTupleShape(shapes);
  return OkStatus();
}

// Creates a vector that maps from the parameters of the XLA computation to
// their original argument positions.
// The MLIR-based TF-Compiler bridge doesn't have constant analysis yet, so no
// inputs are known constants. Therefore, the mapping from inputs to
// computation arguments is a trivial in-order 1:1 mapping.
// TODO(ycao): Support computations with compile-time constants, which require
// a non-trivial input mapping, unlike the one implemented now.
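// For example, num_inputs = 3 yields the identity mapping {0, 1, 2}.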
void GetInputMappingForMlir(int num_inputs, std::vector<int>* input_mapping) {
  input_mapping->resize(num_inputs, 0);
  std::iota(input_mapping->begin(), input_mapping->end(), 0);
}

static void RegisterDialects(mlir::DialectRegistry& registry) {
  mlir::RegisterAllTensorFlowDialects(registry);
  mlir::mhlo::registerAllMhloDialects(registry);
}

// Checks if functions can be inlined after TF -> HLO legalization. Currently
// only TPUs are supported, to follow the behavior of inlining functions via
// the Graph-based bridge in the TPUCompile op kernel.
bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
  return device_type == DEVICE_TPU_XLA_JIT;
}

}  //  namespace

Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
                    mlir::ModuleOp module) {
  auto producer_or = GetTfGraphProducerVersion(module);
  if (!producer_or.ok()) return producer_or.status();
  int64_t producer_version = producer_or.ValueOrDie();

  llvm::SmallVector<int64_t, 16> shape_backing;
  llvm::SmallVector<llvm::ArrayRef<int64_t>, 4> arg_shapes_copy;
  {
    // Convert arg_shapes to an MLIR-friendly format.
    size_t count = 0;
    for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
      if (tensor_resource_shape.is_resource) continue;
      count += tensor_resource_shape.shape.dims();
    }
    shape_backing.resize(count);
    arg_shapes_copy.reserve(arg_shapes.size());
    size_t offset = 0;
    for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
      if (tensor_resource_shape.is_resource) {
        arg_shapes_copy.push_back(llvm::ArrayRef<int64_t>());
        continue;
      }
      size_t start = offset;
      for (tensorflow::TensorShapeDim dim : tensor_resource_shape.shape) {
        shape_backing[offset] = dim.size;
        ++offset;
      }
      if (offset == start) {
        arg_shapes_copy.push_back(llvm::ArrayRef<int64_t>());
      } else {
        arg_shapes_copy.push_back(
            llvm::ArrayRef<int64_t>(&shape_backing[start], offset - start));
      }
    }
  }

  auto main_func = module.lookupSymbol<mlir::func::FuncOp>("main");

  mlir::StatusScopedDiagnosticHandler error_handler(module.getContext());
  mlir::LogicalResult result = mlir::TF::InferShapeForFunction(
      main_func, arg_shapes_copy, producer_version);

  if (failed(result)) {
    return error_handler.Combine(
        errors::Internal("MLIR Shape refinement failed"));
  }
  return error_handler.ConsumeStatus();
}

void CreateConvertMlirToXlaHloPipeline(
    mlir::OpPassManager& pm, llvm::StringRef device_type, bool prefer_tf2xla,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes,
    bool allow_partial_conversion) {
  // Note that the region-based control-flow produced here still contains
  // function call ops which get inlined by the subsequent inliner pass.
  pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
  pm.addPass(mlir::createInlinerPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::TF::CreateDropWhileShapeInvariantPass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  // The SCCP pass performs constant propagation across the IR, which, for
  // example, propagates constant arguments into callee functions.
  // TODO(hinsu): Investigate if we really need SCCP pass before shape inference
  // and can do with just one pass after the shape inference.
  pm.addPass(mlir::createSCCPPass());
  // Guarantee all functions have one use, which enables shape inference.
  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  // Run shape inference pass before tensorlist decomposition to get buffer
  // shape of uninitialized TensorLists.
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());

  // Run SCCP pass again as the availability of shapes may open up new
  // opportunities for constant propagation. Note that the shape inference pass
  // doesn't materialize new constants even if those are computed internally for
  // the purpose of shape inference. These constants might be required by the
  // legalization passes.
  pm.addPass(mlir::createSCCPPass());

  pm.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
  pm.addPass(mlir::TF::CreateStackOpsDecompositionPass());
  pm.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::TFDevice::CreateDecomposeResourceOpsPass());
  pm.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
  pm.addPass(mlir::createSymbolDCEPass());
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  // TODO(b/171426148): We cannot completely remove region to functional control
  // flow conversion from this pipeline yet as it causes some unit tests to
  // fail.
  pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
  // LegalizeTFControlFlow encapsulates arguments for control flow operations
  // with a tuple argument which breaks the assumption of resource lifting
  // inside PromoteResourcesToArgs.
  pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());

  pm.addNestedPass<mlir::func::FuncOp>(mlir::TF::CreateLowerQuantizedPass());
  pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass());
  pm.addPass(mlir::mhlo::createLegalizeTFModulePass(
      /*tf2xla_fallback_device_type=*/device_type));
  pm.addNestedPass<mlir::func::FuncOp>(mlir::mhlo::createLegalizeTFPass(
      /*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
      /*tf2xla_fallback_device_type=*/device_type, prefer_tf2xla));
  for (auto& target_pass : custom_legalization_passes) {
    pm.addNestedPass<mlir::func::FuncOp>(std::move(target_pass));
  }
  pm.addNestedPass<mlir::func::FuncOp>(mlir::mhlo::CreateAdjustLayoutPass());
  pm.addPass(mlir::mhlo::CreateLegalizeTFCommunicationPass());
  pm.addPass(mlir::mhlo::CreateLegalizeTFCollectivePass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  // Run shape inference pass to propagate shapes through tensor_cast operations
  // from static to dynamic shapes. This could be generated if the shape
  // inference was originally missing in a TF op but the corresponding HLO op
  // had static shape after lowering.
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  // Run LegalizeTFPass again because the previous legalization passes can
  // expose more graph pruning and canonicalization opportunities that are
  // necessary for the second LegalizeTFPass(allow_partial_conversion=false)
  // invocation.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::mhlo::createLegalizeTFPass(
      /*allow_partial_conversion=*/allow_partial_conversion,
      /*legalize_chlo=*/true,
      /*tf2xla_fallback_device_type=*/device_type, prefer_tf2xla));

  if (CanInlineFunctionsPostLegalization(device_type))
    pm.addPass(mlir::createInlinerPass());

  // In order to export to XLA, we must sink constants to control flow regions,
  // since XLA uses functional control flow.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mhlo::createSinkConstantsToControlFlowPass());
}

Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type,
                     bool prefer_tf2xla,
                     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                         custom_legalization_passes) {
  mlir::PassManager tf2xla(module_op.getContext());
  applyTensorflowAndCLOptions(tf2xla);
  CreateConvertMlirToXlaHloPipeline(tf2xla, device_type, prefer_tf2xla,
                                    custom_legalization_passes);

  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("legalize_hlo_before", module_op, "", &tf2xla);
  if (VLOG_IS_ON(2)) {
    // Print the whole module after each pass which requires disabling
    // multi-threading as well.
    module_op.getContext()->disableMultithreading();
    tf2xla.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
        /*print_module_scope=*/true));
  }

  // Make sure we catch any error reported by MLIR and forward it to the TF
  // error reporting system. Report a generic error if pass manager failed
  // without emitting a diagnostic.
  mlir::StatusScopedDiagnosticHandler error_handler(module_op.getContext());

  if (failed(tf2xla.run(module_op))) {
    Status status = errors::InvalidArgument("TF to XLA legalization failed: ");
    tensorflow::OkOrSetErrorCounterPayload(
        tensorflow::core::platform::ErrorSourceProto::MLIR_BRIDGE_PHASE_2,
        status);
    return error_handler.Combine(status);
  }

  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("legalize_hlo_after", module_op, "", &tf2xla);
  Status status = error_handler.ConsumeStatus();
  tensorflow::OkOrSetErrorCounterPayload(
      tensorflow::core::platform::ErrorSourceProto::MLIR_BRIDGE_PHASE_2,
      status);
  return status;
}

Status BuildHloFromTfInner(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
                           llvm::ArrayRef<xla::XlaOp> xla_params,
                           std::vector<xla::XlaOp>& returns,
                           llvm::StringRef device_type,
                           llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                               custom_legalization_passes) {
  TF_RETURN_IF_ERROR(LegalizeToHlo(module_op, device_type,
                                   /*prefer_tf2xla=*/false,
                                   custom_legalization_passes));

  mlir::Block& block =
      module_op.lookupSymbol<mlir::func::FuncOp>("main").front();
  return mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns);
}

Status ConvertMLIRToXlaComputation(
    mlir::ModuleOp module_op, llvm::StringRef device_type,
    xla::XlaComputation* xla_computation, bool use_tuple_args,
    bool prefer_tf2xla, bool return_tuple,
    const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  TF_RETURN_IF_ERROR(LegalizeToHlo(module_op, device_type, prefer_tf2xla,
                                   custom_legalization_passes));

  mlir::MlirToHloConversionOptions options;
  options.layout_preference_fn =
      [&](const xla::Shape& xla_shape) -> StatusOr<mlir::XlaLayoutPreference> {
    TensorShape shape;
    TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                            xla_shape.element_type()));
    return shape_determination_fns.layout_preference_fn(shape, dtype,
                                                        std::nullopt);
  };
  options.shape_representation_fn =
      [&](const xla::Shape& xla_shape, bool fast_mem,
          mlir::XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
    TensorShape shape;
    TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                            xla_shape.element_type()));
    return shape_determination_fns.shape_representation_fn(
        shape, dtype, fast_mem, layout_preference);
  };
  xla::HloProto hlo_proto;
  TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(
      module_op, &hlo_proto, use_tuple_args, return_tuple, options));
  *xla_computation = xla::XlaComputation(hlo_proto.hlo_module());
  return OkStatus();
}

Status CompileMlirSetup(mlir::ModuleOp module_op,
                        llvm::ArrayRef<TensorOrResourceShape> arg_shapes) {
  // Use arg_shapes to improve the mlir type information of `main` in module_op.
  TF_RETURN_IF_ERROR(RefineShapes(arg_shapes, module_op));

  if (VLOG_IS_ON(2))
    tensorflow::DumpMlirOpToFile("compile_mlir_shape_refiner", module_op);

  return OkStatus();
}

Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
                      llvm::ArrayRef<xla::XlaOp> xla_params,
                      std::vector<xla::XlaOp>& returns,
                      llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
                      llvm::StringRef device_type,
                      llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                          custom_legalization_passes) {
  if (VLOG_IS_ON(2))
    tensorflow::DumpMlirOpToFile("build_hlo_tf_before", module_op);

  TF_RETURN_IF_ERROR(CompileMlirSetup(module_op, arg_shapes));

  // Convert MLIR module to XLA HLO proto contained in XlaComputation.
  TF_RETURN_IF_ERROR(BuildHloFromTfInner(module_op, builder, xla_params,
                                         returns, device_type,
                                         custom_legalization_passes));

  if (VLOG_IS_ON(2))
    tensorflow::DumpMlirOpToFile("build_hlo_tf_after", module_op);

  return OkStatus();
}

Status PopulateCollectiveInfo(mlir::ModuleOp module_op,
                              XlaCompilationResult* compilation_result) {
  // The StringRef cast is necessary before cxx14.
  mlir::IntegerAttr group_key_attr =
      module_op->getAttrOfType<mlir::IntegerAttr>(
          mlir::StringRef(kGroupKeyAttrName.data(), kGroupKeyAttrName.size()));
  mlir::IntegerAttr group_size_attr =
      module_op->getAttrOfType<mlir::IntegerAttr>(mlir::StringRef(
          kGroupSizeAttrName.data(), kGroupSizeAttrName.size()));
  if (group_key_attr == nullptr && group_size_attr == nullptr) {
    // No CollectiveInfo is present.
    return OkStatus();
  }
  DCHECK(group_key_attr != nullptr)
      << "module attribute " << kGroupKeyAttrName
      << " is required for CollectiveInfo but not found.";
  DCHECK(group_size_attr != nullptr)
      << "module attribute " << kGroupSizeAttrName
      << " is required for CollectiveInfo but not found.";
  int32_t group_key = group_key_attr.getInt();
  int32_t group_size = group_size_attr.getInt();
  VLOG(2) << "Populating CollectiveInfo: group_key=" << group_key
          << " group_size=" << group_size;
  compilation_result->collective_info = {group_key, group_size, 0};
  return OkStatus();
}

Status PopulateResultIOInfo(
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    bool use_tuple_args, bool use_resource_updates_for_aliases,
    const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    XlaCompilationResult* compilation_result) {
  // Construct mapping from XlaComputation's arg to input edges of execute
  // node.
  GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping);

  // Compute all input shapes.
  TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
                                       shape_determination_fns,
                                       &compilation_result->xla_input_shapes));

  // Compute all output descriptions and resource writes.
  return GetOutputInfo(
      module_op, use_resource_updates_for_aliases, shape_determination_fns,
      &compilation_result->xla_output_shape, &compilation_result->outputs,
      &compilation_result->resource_updates);
}

Status CompileMlirToXlaHlo(
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
    bool use_return_tuple, bool use_resource_updates_for_aliases,
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  if (analyse_graph &&
      GetMlirBridge2ndPhaseRolloutPolicy(module_op) ==
          MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis) {
    return CompileToHloGraphAnalysisFailedError();
  }

  TF_RETURN_IF_ERROR(CompileMlirSetup(module_op, arg_shapes));

  // Convert MLIR module to XLA HLO proto contained in XlaComputation.
  compilation_result->computation = std::make_shared<xla::XlaComputation>();
  TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
      module_op, device_type, compilation_result->computation.get(),
      use_tuple_args, analyse_graph, use_return_tuple, shape_determination_fns,
      custom_legalization_passes));

  TF_RETURN_IF_ERROR(PopulateCollectiveInfo(module_op, compilation_result));

  return PopulateResultIOInfo(module_op, arg_shapes, use_tuple_args,
                              use_resource_updates_for_aliases,
                              shape_determination_fns, compilation_result);
}

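// Illustrative usage sketch for the function below (hypothetical caller;
// `serialized_module` and `shape_determination_fns` are assumed to be set up
// by the caller, and custom legalization passes are left at their default):
//
//   XlaCompilationResult result;
//   TF_RETURN_IF_ERROR(CompileSerializedMlirToXlaHlo(
//       serialized_module, /*arg_shapes=*/{TensorShape({2, 3})},
//       /*device_type=*/"XLA_CPU_JIT", /*use_tuple_args=*/true,
//       /*analyse_graph=*/false, shape_determination_fns, &result));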
Status CompileSerializedMlirToXlaHlo(
    llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
    llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
    const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  mlir::DialectRegistry mlir_registry;
  RegisterDialects(mlir_registry);
  mlir::MLIRContext mlir_context(mlir_registry);
  mlir::OwningOpRef<mlir::ModuleOp> mlir_module;

  TF_RETURN_IF_ERROR(
      DeserializeMlirModule(mlir_module_string, &mlir_context, &mlir_module));
  llvm::SmallVector<TensorOrResourceShape, 4> tensor_or_resource_shapes;
  tensor_or_resource_shapes.reserve(arg_shapes.size());
  for (const auto& arg_shape : arg_shapes)
    tensor_or_resource_shapes.push_back({arg_shape});
  return CompileMlirToXlaHlo(
      mlir_module.get(), tensor_or_resource_shapes, device_type, use_tuple_args,
      analyse_graph, /*use_return_tuple=*/true,
      /*use_resource_updates_for_aliases=*/false, shape_determination_fns,
      compilation_result, custom_legalization_passes);
}

// Rewrites the given module with the specified args. Each constant arg is
// inlined in the "main" function and the corresponding argument is removed
// from the signature. For resource args, their subtypes are populated.
// Returns the original indices for the other arguments on success.
static StatusOr<std::vector<int>> RewriteWithArgs(
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args) {
  mlir::func::FuncOp main_fn =
      module_op.lookupSymbol<mlir::func::FuncOp>("main");
  std::vector<int> params;

  bool has_resource_args = false;
  auto builder = mlir::OpBuilder(main_fn.getBody());
  std::vector<int> args_to_erase;
  for (int idx = 0; idx < args.size(); idx++) {
    const XlaArgument& xla_arg = args[idx];
    mlir::BlockArgument mlir_arg = main_fn.getArgument(idx);
    if (xla_arg.kind == XlaArgument::kResource) {
      mlir::Type element_type;
      if (xla_arg.type == DT_INVALID) {
        return errors::Unimplemented(absl::StrCat(
            "Argument ", idx,
            " is an uninitialized resource variable which is currently"
            " unsupported in the MLIR-based TPU bridge"));
      }
      TF_RETURN_IF_ERROR(ConvertDataType(xla_arg.type, builder, &element_type));
      TF_ASSIGN_OR_RETURN(TensorShape arg_shape,
                          GetTensorShapeFromXlaArgument(xla_arg));
      auto resource_shape = arg_shape.dim_sizes();
      llvm::SmallVector<int64_t, 4> resource_subtype_shape(
          resource_shape.begin(), resource_shape.end());
      auto resource_subtype =
          mlir::RankedTensorType::get(resource_subtype_shape, element_type);
      auto resource_type =
          mlir::TF::ResourceType::get({resource_subtype}, builder.getContext());

      auto tensor_type = mlir_arg.getType().cast<mlir::TensorType>();
      if (tensor_type.hasRank()) {
        mlir_arg.setType(
            mlir::RankedTensorType::get(tensor_type.getShape(), resource_type));
      } else {
        mlir_arg.setType(mlir::UnrankedTensorType::get(resource_type));
      }
      has_resource_args = true;
    }
    if (xla_arg.kind != XlaArgument::kConstant) {
      params.push_back(idx);
      continue;
    }

    TF_ASSIGN_OR_RETURN(auto value_attr,
                        ConvertTensor(xla_arg.constant_value, &builder));
    // TODO(hinsu): Use the actual location of the constant.
    auto constant = builder.create<mlir::TF::ConstOp>(
        mlir::UnknownLoc::get(module_op.getContext()), value_attr);
    mlir_arg.replaceAllUsesWith(constant);
    args_to_erase.push_back(idx);
  }

  if (has_resource_args) {
    llvm::SmallVector<mlir::Type, 4> updated_argument_types;
    updated_argument_types.reserve(main_fn.getNumArguments());
    for (mlir::BlockArgument& arg : main_fn.getArguments())
      updated_argument_types.push_back(arg.getType());

    main_fn.setType(
        mlir::FunctionType::get(main_fn.getContext(), updated_argument_types,
                                main_fn.getFunctionType().getResults()));
  }

  for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx);

  return params;
}

Status CompileGraphSetup(
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
    std::vector<int>* remaining_params,
    llvm::SmallVector<TensorOrResourceShape, 4>& arg_shapes) {
  TF_ASSIGN_OR_RETURN(*remaining_params, RewriteWithArgs(module_op, args));
  arg_shapes.reserve(remaining_params->size());
  for (unsigned idx : *remaining_params) {
    const auto& arg = args[idx];
    TF_ASSIGN_OR_RETURN(TensorShape arg_shape,
                        GetTensorShapeFromXlaArgument(arg));
    arg_shapes.push_back({arg_shape,
                          /*is_resource=*/arg.kind == XlaArgument::kResource});
  }

  mlir::PassManager pm(module_op.getContext());
  applyTensorflowAndCLOptions(pm);
  mlir::TF::StandardPipelineOptions tf_options;
  mlir::TF::CreateTFStandardPipeline(pm, tf_options);

  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("compile_graph_setup_before", module_op);
  if (VLOG_IS_ON(2)) {
    // Print the whole module after each pass which requires disabling
    // multi-threading as well.
    module_op.getContext()->disableMultithreading();
    pm.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
        /*print_module_scope=*/true));
  }

  mlir::StatusScopedDiagnosticHandler diag_handler(module_op.getContext());
  if (failed(pm.run(module_op))) return diag_handler.ConsumeStatus();
  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("compile_graph_setup_after", module_op);

  return OkStatus();
}

Status BuildHloFromModule(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
                          llvm::ArrayRef<xla::XlaOp> xla_params,
                          std::vector<xla::XlaOp>& returns,
                          llvm::ArrayRef<XlaArgument> args,
                          llvm::StringRef device_type,
                          llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                              custom_legalization_passes) {
  std::vector<int> remaining_params;
  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
  TF_RETURN_IF_ERROR(
      CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));
  // Passing down only remaining (non-constant) xla_params.
  llvm::SmallVector<xla::XlaOp, 2> remaining_xla_params;
  for (auto i : remaining_params) remaining_xla_params.push_back(xla_params[i]);
  return BuildHloFromTf(module_op, builder, remaining_xla_params, returns,
                        arg_shapes, device_type, custom_legalization_passes);
}

Status CompileGraphToXlaHlo(
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
    llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
    bool use_return_tuple,
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  std::vector<int> remaining_params;
  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
  TF_RETURN_IF_ERROR(
      CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));

  auto status = CompileMlirToXlaHlo(
      module_op, arg_shapes, device_type, use_tuple_args, analyse_graph,
      use_return_tuple,
      /*use_resource_updates_for_aliases=*/true, shape_determination_fns,
      compilation_result, custom_legalization_passes);
  compilation_result->input_mapping = remaining_params;
  return status;
}

xla::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GraphToModule(
    const Graph& graph, llvm::ArrayRef<std::string> control_rets,
    const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
    mlir::MLIRContext* context) {
  mlir::DialectRegistry registry;
  RegisterDialects(registry);
  context->appendDialectRegistry(registry);
  GraphImportConfig config;
  config.graph_as_function = true;
  config.control_outputs = control_rets;
  // Disable shape inference during import, as some TensorFlow ops fail during
  // shape inference with dynamic shaped operands. This in turn causes the
  // import to fail. Shape inference during import is going to be removed;
  // since the shape inference pass runs early in the pass pipeline, shape
  // inference during import is not necessary.
  config.enable_shape_inference = false;
  return ConvertGraphToMlir(graph, debug_info, flib_def, config, context);
}

Status BuildHloFromGraph(
    const Graph& graph, xla::XlaBuilder& builder,
    mlir::MLIRContext& mlir_context, llvm::ArrayRef<xla::XlaOp> xla_params,
    std::vector<xla::XlaOp>& returns, llvm::ArrayRef<XlaArgument> args,
    llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
    const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  TF_ASSIGN_OR_RETURN(
      mlir::OwningOpRef<mlir::ModuleOp> module,
      GraphToModule(graph, control_rets, flib_def, debug_info, &mlir_context));
  return BuildHloFromModule(module.get(), builder, xla_params, returns, args,
                            device_type, custom_legalization_passes);
}

Status CompileGraphToXlaHlo(
    const Graph& graph, llvm::ArrayRef<XlaArgument> args,
    llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
    bool use_tuple_args, bool analyse_graph,
    const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  mlir::MLIRContext context;
  TF_ASSIGN_OR_RETURN(
      mlir::OwningOpRef<mlir::ModuleOp> module,
      GraphToModule(graph, control_rets, flib_def, debug_info, &context));
  return CompileGraphToXlaHlo(
      module.get(), args, device_type, use_tuple_args, analyse_graph,
      /*use_return_tuple=*/true, shape_determination_fns, compilation_result,
      custom_legalization_passes);
}

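// Registers the "tf-to-hlo-pipeline" pass pipeline below so it can be invoked
// from opt-like tools that link in this registration, e.g. (illustrative; the
// exact flags depend on the tool and MLIR version):
//
//   tf-opt --tf-to-hlo-pipeline input.mlir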
void RegisterConvertMlirToXlaHloPipelineWithDefaults() {
  static mlir::PassPipelineRegistration<> pipeline(
      "tf-to-hlo-pipeline",
      "Convert TF dialect to HLO dialect (used for compilation in bridge).",
      [](mlir::OpPassManager& pm) {
        tensorflow::CreateConvertMlirToXlaHloPipeline(
            pm, /*device_type=*/"XLA_CPU_JIT", /*prefer_tf2xla=*/false,
            /*custom_legalization_passes=*/{});
      });
}

}  // namespace tensorflow