/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"

#include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
#include "tensorflow/compiler/mlir/xla/layout_util.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/transforms/adjust_layout.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/tf2xla/layout_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/register.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/error_payloads.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/core_platform_payloads.pb.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {

constexpr absl::string_view kGroupSizeAttrName =
    "tf2xla.collective_info.group_size";
constexpr absl::string_view kGroupKeyAttrName =
    "tf2xla.collective_info.group_key";
// Extracts the shape from an XlaArgument as a TensorShape. If the shape is an
// xla::Shape, it is converted to a TensorShape.
StatusOr<TensorShape> GetTensorShapeFromXlaArgument(const XlaArgument& arg) {
  if (absl::holds_alternative<xla::Shape>(arg.shape)) {
    TensorShape arg_shape;
    TF_RETURN_IF_ERROR(
        XLAShapeToTensorShape(std::get<xla::Shape>(arg.shape), &arg_shape));
    return arg_shape;
  } else {
    return std::get<TensorShape>(arg.shape);
  }
}

Status MaybeRewriteLayoutWithShardedShape(
    mlir::StringAttr sharding,
    const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    xla::Shape* shape) {
  if (!sharding) return OkStatus();

  xla::OpSharding op_sharding;
  if (!op_sharding.ParseFromString(sharding.getValue().str()))
    return errors::InvalidArgument("failed to parse sharding '",
                                   sharding.getValue().str(), "'");
  std::optional<xla::HloSharding> hlo_sharding;
  TF_ASSIGN_OR_RETURN(hlo_sharding, xla::HloSharding::FromProto(op_sharding));
  TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
      hlo_sharding, /*use_fast_memory=*/false, shape_determination_fns, shape));
  return OkStatus();
}

// Converts arg_shapes to xla::Shapes and stores them in xla_input_shapes.
Status GetXlaInputShapes(
    mlir::ModuleOp module, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    bool use_tuple_args,
    const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    std::vector<xla::Shape>* xla_input_shapes) {
  xla_input_shapes->clear();

  mlir::func::FuncOp main_func =
      module.lookupSymbol<mlir::func::FuncOp>("main");
  TF_RET_CHECK(main_func != nullptr) << "No main function found";
  mlir::FunctionType func_type = main_func.getFunctionType();

  int num_args = func_type.getNumInputs();
  xla_input_shapes->reserve(num_args);

  std::vector<xla::Shape> individual_arg_shapes;
  individual_arg_shapes.reserve(num_args);
  for (int i = 0; i < num_args; ++i) {
    individual_arg_shapes.emplace_back();
    xla::Shape& xla_shape = individual_arg_shapes.back();

    DataType arg_dtype;
    TF_RETURN_IF_ERROR(ConvertToDataType(func_type.getInput(i), &arg_dtype));

    auto layout_preference = shape_determination_fns.layout_preference_fn(
        arg_shapes[i].shape, arg_dtype, std::nullopt);
    TF_ASSIGN_OR_RETURN(xla_shape,
                        shape_determination_fns.shape_representation_fn(
                            arg_shapes[i].shape, arg_dtype,
                            /*use_fast_memory=*/false, layout_preference));

    // Rewrite the layout with sharding, if sharding is set.
    auto sharding =
        main_func.getArgAttrOfType<mlir::StringAttr>(i, "mhlo.sharding");
    TF_RETURN_IF_ERROR(MaybeRewriteLayoutWithShardedShape(
        sharding, shape_determination_fns, &xla_shape));
  }
  if (use_tuple_args) {
    xla_input_shapes->push_back(
        xla::ShapeUtil::MakeTupleShape(individual_arg_shapes));
  } else {
    *xla_input_shapes = individual_arg_shapes;
  }
  return OkStatus();
}
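
// Illustrative usage sketch for GetXlaInputShapes above (not part of the
// build; `module`, `shapes`, and `fns` are hypothetical values owned by the
// caller):
//
//   std::vector<xla::Shape> xla_input_shapes;
//   TF_RETURN_IF_ERROR(GetXlaInputShapes(module, /*arg_shapes=*/shapes,
//                                        /*use_tuple_args=*/true, fns,
//                                        &xla_input_shapes));
//
// With use_tuple_args=true the result is a single tuple shape wrapping all
// arguments; otherwise there is one xla::Shape per argument.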

// Returns a static ranked tensor type corresponding to the given static or
// bounded type by using the bounds as dimension sizes. Returns null if it is
// neither.
mlir::RankedTensorType GetBufferType(mlir::Type ty) {
  auto ranked_ty = ty.dyn_cast_or_null<mlir::RankedTensorType>();
  if (!ranked_ty) return {};

  int64_t rank = ranked_ty.getRank();
  llvm::SmallVector<int64_t, 4> dims = llvm::to_vector<4>(ranked_ty.getShape());
  auto encoding = ranked_ty.getEncoding()
                      .dyn_cast_or_null<mlir::mhlo::TypeExtensionsAttr>();
  if (encoding && !encoding.getBounds().empty()) {
    for (int64_t dim = 0; dim < rank; ++dim) {
      if (dims[dim] == mlir::ShapedType::kDynamicSize) {
        dims[dim] = encoding.getBounds()[dim];
      }
    }
  }
  return mlir::RankedTensorType::get(dims, ranked_ty.getElementType());
}
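
// Illustrative example of what GetBufferType computes (an IR sketch, not
// verified input): a bounded dynamic MHLO type such as
//   tensor<?x4xf32, #mhlo.type_extensions<bounds = [8, -1]>>
// maps to the static type tensor<8x4xf32>; dynamic dimensions are replaced by
// their bounds so a fixed-size buffer type can be derived.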

// Calculates the computation output shape and builds an OutputDescription for
// each output based on static shapes in the MLIR module. If an output is a
// resource write, `resource_updates` is populated instead of `outputs` for
// that output.
Status GetOutputInfo(
    mlir::ModuleOp module, bool use_resource_updates_for_aliases,
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    xla::Shape* xla_output_shape, std::vector<XlaOutputDescription>* outputs,
    std::vector<XlaResourceUpdate>* resource_updates) {
  auto shape_representation_fn_no_fast_memory =
      [shape_determination_fns](
          const xla::Shape& xla_shape) -> StatusOr<xla::Shape> {
    TensorShape shape;
    TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                            xla_shape.element_type()));
    auto layout_preference = shape_determination_fns.layout_preference_fn(
        shape, dtype, std::nullopt);
    return shape_determination_fns.shape_representation_fn(
        shape, dtype, /*use_fast_memory=*/false, layout_preference);
  };

  mlir::func::FuncOp main_func =
      module.lookupSymbol<mlir::func::FuncOp>("main");
  mlir::FunctionType func_type = main_func.getFunctionType();

  outputs->clear();
  outputs->reserve(func_type.getNumResults());
  resource_updates->clear();
  resource_updates->reserve(func_type.getNumResults());

  std::vector<xla::Shape> shapes;
  shapes.reserve(func_type.getNumResults());

  llvm::SmallDenseMap<unsigned, unsigned> output_to_input_alias;
  for (unsigned i = 0; i < main_func.getNumArguments(); ++i)
    if (auto aliasing_output = main_func.getArgAttrOfType<mlir::IntegerAttr>(
            i, "tf.aliasing_output"))
      output_to_input_alias[aliasing_output.getInt()] = i;

  auto return_op = main_func.begin()->getTerminator();
  for (const auto& type_and_idx : llvm::enumerate(func_type.getResults())) {
    size_t idx = type_and_idx.index();
    auto result_ty = type_and_idx.value().cast<mlir::RankedTensorType>();

    // If the result type isn't static, then the owner of the result may be a
    // cast op from a more specific bounded type to an unbounded dynamic type.
    // Use the bounded type to get the buffer size.
    mlir::RankedTensorType buffer_ty = result_ty;
    if (!buffer_ty.hasStaticShape()) {
      mlir::Value return_val = return_op->getOperand(idx);
      if (auto owner = mlir::dyn_cast_or_null<mlir::tensor::CastOp>(
              return_val.getDefiningOp())) {
        // For a bounded dynamic type, get a static size by taking the bounds
        // as the dimensions. These dimensions are marked as dynamic in the
        // xla::Shape below.
        buffer_ty = GetBufferType(owner.getOperand().getType());
        if (!buffer_ty || !buffer_ty.hasStaticShape()) {
          return errors::InvalidArgument(
              "results needs to be static or bounded");
        }
      }
    }

    xla::Shape shape = xla::TypeToShape(buffer_ty);
    if (shape.element_type() == xla::PRIMITIVE_TYPE_INVALID) {
      return errors::InvalidArgument("XLA conversion failed for MLIR type.");
    }
    TF_ASSIGN_OR_RETURN(shape, shape_representation_fn_no_fast_memory(shape));

    if (!result_ty.hasStaticShape()) {
      int64_t rank = result_ty.getRank();
      for (int64_t dim = 0; dim < rank; ++dim) {
        if (result_ty.isDynamicDim(dim)) {
          shape.set_dynamic_dimension(dim, true);
        }
      }
    }

    auto sharding = main_func.getResultAttrOfType<mlir::StringAttr>(
        type_and_idx.index(), "mhlo.sharding");
    TF_RETURN_IF_ERROR(MaybeRewriteLayoutWithShardedShape(
        sharding, shape_determination_fns, &shape));

    auto tensor_type = type_and_idx.value().dyn_cast<mlir::RankedTensorType>();
    shapes.push_back(shape);

    auto it = output_to_input_alias.find(type_and_idx.index());
    if (it != output_to_input_alias.end() && use_resource_updates_for_aliases) {
      // Add a resource write.
      resource_updates->emplace_back();
      XlaResourceUpdate& resource_update = resource_updates->back();
      resource_update.input_index = it->getSecond();
      resource_update.modified = true;
      TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &resource_update.type));
      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &resource_update.shape));
      continue;
    }
    // Construct an OutputDescription for the result.
    outputs->emplace_back();
    XlaOutputDescription& out_desc = outputs->back();
    TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &out_desc.type));
    // TODO(ycao): Support constant output.
    out_desc.is_constant = false;
    TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &out_desc.shape));
    // input_index is only meaningful for resource outputs. Set it to the
    // sentinel value -1 for non-resource outputs.
    out_desc.input_index =
        it != output_to_input_alias.end() ? it->getSecond() : -1;
    // The MLIR-based TF-Compiler bridge doesn't support tensorlist output yet.
    // TODO(ycao): Support tensorlist-type output.
    out_desc.is_tensor_list = false;
  }

  // The XLA computation always uses a tuple shape for its result.
  *xla_output_shape = xla::ShapeUtil::MakeTupleShape(shapes);
  return OkStatus();
}

// Creates a vector that maps from the parameters of the XLA computation to
// their original argument positions.
// The MLIR-based TF-Compiler bridge doesn't have constant analysis yet, so no
// inputs are known constants. Therefore, the mapping from inputs to
// computation arguments is a trivial in-order 1-1 mapping.
// TODO(ycao): Support computations with compile-time constants, which requires
// a non-trivial input mapping rather than the trivial one implemented now.
void GetInputMappingForMlir(int num_inputs, std::vector<int>* input_mapping) {
  input_mapping->resize(num_inputs, 0);
  std::iota(input_mapping->begin(), input_mapping->end(), 0);
}
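
// For example, with num_inputs = 3 the mapping produced above is simply
// {0, 1, 2}.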

static void RegisterDialects(mlir::DialectRegistry& registry) {
  mlir::RegisterAllTensorFlowDialects(registry);
  mlir::mhlo::registerAllMhloDialects(registry);
}

// Checks if functions can be inlined after TF -> HLO legalization. Currently
// only TPUs are supported, to follow the behavior of inlining functions via
// the Graph-based bridge in the TPUCompile op kernel.
bool CanInlineFunctionsPostLegalization(llvm::StringRef device_type) {
  return device_type == DEVICE_TPU_XLA_JIT;
}

}  // namespace

Status RefineShapes(llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
                    mlir::ModuleOp module) {
  auto producer_or = GetTfGraphProducerVersion(module);
  if (!producer_or.ok()) return producer_or.status();
  int64_t producer_version = producer_or.ValueOrDie();

  llvm::SmallVector<int64_t, 16> shape_backing;
  llvm::SmallVector<llvm::ArrayRef<int64_t>, 4> arg_shapes_copy;
  {
    // Convert arg_shapes to an MLIR-friendly format.
    size_t count = 0;
    for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
      if (tensor_resource_shape.is_resource) continue;
      count += tensor_resource_shape.shape.dims();
    }
    shape_backing.resize(count);
    arg_shapes_copy.reserve(arg_shapes.size());
    size_t offset = 0;
    for (const TensorOrResourceShape& tensor_resource_shape : arg_shapes) {
      if (tensor_resource_shape.is_resource) {
        arg_shapes_copy.push_back(llvm::ArrayRef<int64_t>());
        continue;
      }
      size_t start = offset;
      for (tensorflow::TensorShapeDim dim : tensor_resource_shape.shape) {
        shape_backing[offset] = dim.size;
        ++offset;
      }
      if (offset == start) {
        arg_shapes_copy.push_back(llvm::ArrayRef<int64_t>());
      } else {
        arg_shapes_copy.push_back(
            llvm::ArrayRef<int64_t>(&shape_backing[start], offset - start));
      }
    }
  }

  auto main_func = module.lookupSymbol<mlir::func::FuncOp>("main");

  mlir::StatusScopedDiagnosticHandler error_handler(module.getContext());
  mlir::LogicalResult result = mlir::TF::InferShapeForFunction(
      main_func, arg_shapes_copy, producer_version);

  if (failed(result)) {
    return error_handler.Combine(
        errors::Internal("MLIR Shape refinement failed"));
  }
  return error_handler.ConsumeStatus();
}

void CreateConvertMlirToXlaHloPipeline(
    mlir::OpPassManager& pm, llvm::StringRef device_type, bool prefer_tf2xla,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes,
    bool allow_partial_conversion) {
  // Note that the region-based control-flow produced here still contains
  // function call ops which get inlined by the subsequent inliner pass.
  pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
  pm.addPass(mlir::createInlinerPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::TF::CreateDropWhileShapeInvariantPass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  // The SCCP pass performs constant propagation across the IR, which, for
  // example, propagates constant arguments into callee functions.
  // TODO(hinsu): Investigate whether we really need the SCCP pass before shape
  // inference or can do with just one pass after shape inference.
  pm.addPass(mlir::createSCCPPass());
  // Guarantee all functions have one use, which enables shape inference.
  pm.addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
  // Run the shape inference pass before tensorlist decomposition to get the
  // buffer shape of uninitialized TensorLists.
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());

  // Run the SCCP pass again as the availability of shapes may open up new
  // opportunities for constant propagation. Note that the shape inference pass
  // doesn't materialize new constants even if those are computed internally
  // for the purpose of shape inference. These constants might be required by
  // the legalization passes.
  pm.addPass(mlir::createSCCPPass());

  pm.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
  pm.addPass(mlir::TF::CreateStackOpsDecompositionPass());
  pm.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::TFDevice::CreateDecomposeResourceOpsPass());
  pm.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
  pm.addPass(mlir::createSymbolDCEPass());
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  // TODO(b/171426148): We cannot completely remove region to functional
  // control flow conversion from this pipeline yet as it causes some unit
  // tests to fail.
  pm.addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
  // LegalizeTFControlFlow encapsulates arguments for control flow operations
  // with a tuple argument which breaks the assumption of resource lifting
  // inside PromoteResourcesToArgs.
  pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());

  pm.addNestedPass<mlir::func::FuncOp>(mlir::TF::CreateLowerQuantizedPass());
  pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass());
  pm.addPass(mlir::mhlo::createLegalizeTFModulePass(
      /*tf2xla_fallback_device_type=*/device_type));
  pm.addNestedPass<mlir::func::FuncOp>(mlir::mhlo::createLegalizeTFPass(
      /*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
      /*tf2xla_fallback_device_type=*/device_type, prefer_tf2xla));
  for (auto& target_pass : custom_legalization_passes) {
    pm.addNestedPass<mlir::func::FuncOp>(std::move(target_pass));
  }
  pm.addNestedPass<mlir::func::FuncOp>(mlir::mhlo::CreateAdjustLayoutPass());
  pm.addPass(mlir::mhlo::CreateLegalizeTFCommunicationPass());
  pm.addPass(mlir::mhlo::CreateLegalizeTFCollectivePass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  // Run the shape inference pass to propagate shapes through tensor_cast
  // operations from static to dynamic shapes. These casts could be generated
  // if shape inference was originally missing in a TF op but the corresponding
  // HLO op had a static shape after lowering.
  pm.addPass(mlir::TF::CreateTFShapeInferencePass());
  // Run LegalizeTFPass again because the previous legalization passes can
  // expose more graph pruning and canonicalization opportunities that are
  // necessary for the second LegalizeTFPass(allow_partial_conversion=false)
  // invocation.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::mhlo::createLegalizeTFPass(
      /*allow_partial_conversion=*/allow_partial_conversion,
      /*legalize_chlo=*/true,
      /*tf2xla_fallback_device_type=*/device_type, prefer_tf2xla));

  if (CanInlineFunctionsPostLegalization(device_type))
    pm.addPass(mlir::createInlinerPass());

  // In order to export to XLA, we must sink constants to control flow regions,
  // since XLA uses functional control flow.
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mhlo::createSinkConstantsToControlFlowPass());
}
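
// Illustrative usage sketch of the pipeline above (module_op here is a
// hypothetical module owned by the caller; this mirrors how LegalizeToHlo
// below drives the pipeline):
//
//   mlir::PassManager pm(module_op.getContext());
//   CreateConvertMlirToXlaHloPipeline(pm, /*device_type=*/"XLA_CPU_JIT",
//                                     /*prefer_tf2xla=*/false,
//                                     /*custom_legalization_passes=*/{});
//   if (failed(pm.run(module_op))) {
//     // Handle legalization failure.
//   }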

Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type,
                     bool prefer_tf2xla,
                     llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                         custom_legalization_passes) {
  mlir::PassManager tf2xla(module_op.getContext());
  applyTensorflowAndCLOptions(tf2xla);
  CreateConvertMlirToXlaHloPipeline(tf2xla, device_type, prefer_tf2xla,
                                    custom_legalization_passes);

  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("legalize_hlo_before", module_op, "", &tf2xla);
  if (VLOG_IS_ON(2)) {
    // Print the whole module after each pass, which requires disabling
    // multi-threading as well.
    module_op.getContext()->disableMultithreading();
    tf2xla.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
        /*print_module_scope=*/true));
  }

  // Make sure we catch any error reported by MLIR and forward it to the TF
  // error reporting system. Report a generic error if the pass manager failed
  // without emitting a diagnostic.
  mlir::StatusScopedDiagnosticHandler error_handler(module_op.getContext());

  if (failed(tf2xla.run(module_op))) {
    Status status = errors::InvalidArgument("TF to XLA legalization failed: ");
    tensorflow::OkOrSetErrorCounterPayload(
        tensorflow::core::platform::ErrorSourceProto::MLIR_BRIDGE_PHASE_2,
        status);
    return error_handler.Combine(status);
  }

  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("legalize_hlo_after", module_op, "", &tf2xla);
  Status status = error_handler.ConsumeStatus();
  tensorflow::OkOrSetErrorCounterPayload(
      tensorflow::core::platform::ErrorSourceProto::MLIR_BRIDGE_PHASE_2,
      status);
  return status;
}

Status BuildHloFromTfInner(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
                           llvm::ArrayRef<xla::XlaOp> xla_params,
                           std::vector<xla::XlaOp>& returns,
                           llvm::StringRef device_type,
                           llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                               custom_legalization_passes) {
  TF_RETURN_IF_ERROR(LegalizeToHlo(module_op, device_type,
                                   /*prefer_tf2xla=*/false,
                                   custom_legalization_passes));

  mlir::Block& block =
      module_op.lookupSymbol<mlir::func::FuncOp>("main").front();
  return mlir::BuildHloFromMlirHlo(block, builder, xla_params, returns);
}

Status ConvertMLIRToXlaComputation(
    mlir::ModuleOp module_op, llvm::StringRef device_type,
    xla::XlaComputation* xla_computation, bool use_tuple_args,
    bool prefer_tf2xla, bool return_tuple,
    const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  TF_RETURN_IF_ERROR(LegalizeToHlo(module_op, device_type, prefer_tf2xla,
                                   custom_legalization_passes));

  mlir::MlirToHloConversionOptions options;
  options.layout_preference_fn =
      [&](const xla::Shape& xla_shape) -> StatusOr<mlir::XlaLayoutPreference> {
    TensorShape shape;
    TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                            xla_shape.element_type()));
    return shape_determination_fns.layout_preference_fn(shape, dtype,
                                                        std::nullopt);
  };
  options.shape_representation_fn =
      [&](const xla::Shape& xla_shape, bool fast_mem,
          mlir::XlaLayoutPreference layout_preference) -> StatusOr<xla::Shape> {
    TensorShape shape;
    TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                            xla_shape.element_type()));
    return shape_determination_fns.shape_representation_fn(
        shape, dtype, fast_mem, layout_preference);
  };
  xla::HloProto hlo_proto;
  TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(
      module_op, &hlo_proto, use_tuple_args, return_tuple, options));
  *xla_computation = xla::XlaComputation(hlo_proto.hlo_module());
  return OkStatus();
}
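
// Illustrative usage sketch for ConvertMLIRToXlaComputation (values are
// hypothetical; CompileMlirToXlaHlo below is the in-tree caller):
//
//   xla::XlaComputation computation;
//   TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
//       module_op, /*device_type=*/"XLA_TPU_JIT", &computation,
//       /*use_tuple_args=*/true, /*prefer_tf2xla=*/false,
//       /*return_tuple=*/true, shape_determination_fns,
//       /*custom_legalization_passes=*/{}));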

Status CompileMlirSetup(mlir::ModuleOp module_op,
                        llvm::ArrayRef<TensorOrResourceShape> arg_shapes) {
  // Use arg_shapes to improve the MLIR type information of `main` in
  // module_op.
  TF_RETURN_IF_ERROR(RefineShapes(arg_shapes, module_op));

  if (VLOG_IS_ON(2))
    tensorflow::DumpMlirOpToFile("compile_mlir_shape_refiner", module_op);

  return OkStatus();
}

Status BuildHloFromTf(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
                      llvm::ArrayRef<xla::XlaOp> xla_params,
                      std::vector<xla::XlaOp>& returns,
                      llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
                      llvm::StringRef device_type,
                      llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                          custom_legalization_passes) {
  if (VLOG_IS_ON(2))
    tensorflow::DumpMlirOpToFile("build_hlo_tf_before", module_op);

  TF_RETURN_IF_ERROR(CompileMlirSetup(module_op, arg_shapes));

  // Lower the MLIR module to HLO and emit the ops into the given XLA builder.
  TF_RETURN_IF_ERROR(BuildHloFromTfInner(module_op, builder, xla_params,
                                         returns, device_type,
                                         custom_legalization_passes));

  if (VLOG_IS_ON(2))
    tensorflow::DumpMlirOpToFile("build_hlo_tf_after", module_op);

  return OkStatus();
}

Status PopulateCollectiveInfo(mlir::ModuleOp module_op,
                              XlaCompilationResult* compilation_result) {
  // The StringRef cast is necessary before cxx14.
  mlir::IntegerAttr group_key_attr =
      module_op->getAttrOfType<mlir::IntegerAttr>(
          mlir::StringRef(kGroupKeyAttrName.data(), kGroupKeyAttrName.size()));
  mlir::IntegerAttr group_size_attr =
      module_op->getAttrOfType<mlir::IntegerAttr>(mlir::StringRef(
          kGroupSizeAttrName.data(), kGroupSizeAttrName.size()));
  if (group_key_attr == nullptr && group_size_attr == nullptr) {
    // No CollectiveInfo is present.
    return OkStatus();
  }
  DCHECK(group_key_attr != nullptr)
      << "module attribute " << kGroupKeyAttrName
      << " is required for CollectiveInfo but not found.";
  DCHECK(group_size_attr != nullptr)
      << "module attribute " << kGroupSizeAttrName
      << " is required for CollectiveInfo but not found.";
  int32_t group_key = group_key_attr.getInt();
  int32_t group_size = group_size_attr.getInt();
  VLOG(2) << "Populating CollectiveInfo: group_key=" << group_key
          << " group_size=" << group_size;
  compilation_result->collective_info = {group_key, group_size, 0};
  return OkStatus();
}
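
// The attributes consumed above are expected as module-level attributes, e.g.
// (IR sketch, values hypothetical):
//
//   module attributes {tf2xla.collective_info.group_key = 1 : i32,
//                      tf2xla.collective_info.group_size = 2 : i32} { ... }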

Status PopulateResultIOInfo(
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    bool use_tuple_args, bool use_resource_updates_for_aliases,
    const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    XlaCompilationResult* compilation_result) {
  // Construct mapping from XlaComputation's arg to input edges of execute
  // node.
  GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping);

  // Compute all input shapes.
  TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
                                       shape_determination_fns,
                                       &compilation_result->xla_input_shapes));

  // Compute all output descriptions and resource writes.
  return GetOutputInfo(
      module_op, use_resource_updates_for_aliases, shape_determination_fns,
      &compilation_result->xla_output_shape, &compilation_result->outputs,
      &compilation_result->resource_updates);
}

Status CompileMlirToXlaHlo(
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorOrResourceShape> arg_shapes,
    llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
    bool use_return_tuple, bool use_resource_updates_for_aliases,
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  if (analyse_graph &&
      GetMlirBridge2ndPhaseRolloutPolicy(module_op) ==
          MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis) {
    return CompileToHloGraphAnalysisFailedError();
  }

  TF_RETURN_IF_ERROR(CompileMlirSetup(module_op, arg_shapes));

  // Convert MLIR module to XLA HLO proto contained in XlaComputation.
  compilation_result->computation = std::make_shared<xla::XlaComputation>();
  TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
      module_op, device_type, compilation_result->computation.get(),
      use_tuple_args, analyse_graph, use_return_tuple, shape_determination_fns,
      custom_legalization_passes));

  TF_RETURN_IF_ERROR(PopulateCollectiveInfo(module_op, compilation_result));

  return PopulateResultIOInfo(module_op, arg_shapes, use_tuple_args,
                              use_resource_updates_for_aliases,
                              shape_determination_fns, compilation_result);
}

Status CompileSerializedMlirToXlaHlo(
    llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
    llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
    const XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  mlir::DialectRegistry mlir_registry;
  RegisterDialects(mlir_registry);
  mlir::MLIRContext mlir_context(mlir_registry);
  mlir::OwningOpRef<mlir::ModuleOp> mlir_module;

  TF_RETURN_IF_ERROR(
      DeserializeMlirModule(mlir_module_string, &mlir_context, &mlir_module));
  llvm::SmallVector<TensorOrResourceShape, 4> tensor_or_resource_shapes;
  tensor_or_resource_shapes.reserve(arg_shapes.size());
  for (const auto& arg_shape : arg_shapes)
    tensor_or_resource_shapes.push_back({arg_shape});
  return CompileMlirToXlaHlo(
      mlir_module.get(), tensor_or_resource_shapes, device_type, use_tuple_args,
      analyse_graph, /*use_return_tuple=*/true,
      /*use_resource_updates_for_aliases=*/false, shape_determination_fns,
      compilation_result, custom_legalization_passes);
}
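
// Illustrative usage sketch for CompileSerializedMlirToXlaHlo (the module
// string, shapes, and shape_determination_fns are hypothetical):
//
//   XlaCompilationResult result;
//   TF_RETURN_IF_ERROR(CompileSerializedMlirToXlaHlo(
//       serialized_mlir_module, /*arg_shapes=*/{TensorShape({2, 3})},
//       /*device_type=*/"XLA_CPU_JIT", /*use_tuple_args=*/true,
//       /*analyse_graph=*/false, shape_determination_fns, &result,
//       /*custom_legalization_passes=*/{}));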

// Rewrites the given module with the specified args. Each constant arg gets
// inlined into the `main` function and the corresponding argument is removed
// from the signature. For resource args, their subtypes are populated.
// Returns the original indices of the remaining arguments on success.
static StatusOr<std::vector<int>> RewriteWithArgs(
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args) {
  mlir::func::FuncOp main_fn =
      module_op.lookupSymbol<mlir::func::FuncOp>("main");
  std::vector<int> params;

  bool has_resource_args = false;
  auto builder = mlir::OpBuilder(main_fn.getBody());
  std::vector<int> args_to_erase;
  for (int idx = 0; idx < args.size(); idx++) {
    const XlaArgument& xla_arg = args[idx];
    mlir::BlockArgument mlir_arg = main_fn.getArgument(idx);
    if (xla_arg.kind == XlaArgument::kResource) {
      mlir::Type element_type;
      if (xla_arg.type == DT_INVALID) {
        return errors::Unimplemented(absl::StrCat(
            "Argument ", idx,
            " is an uninitialized resource variable which is currently"
            " unsupported in the MLIR-based TPU bridge"));
      }
      TF_RETURN_IF_ERROR(ConvertDataType(xla_arg.type, builder, &element_type));
      TF_ASSIGN_OR_RETURN(TensorShape arg_shape,
                          GetTensorShapeFromXlaArgument(xla_arg));
      auto resource_shape = arg_shape.dim_sizes();
      llvm::SmallVector<int64_t, 4> resource_subtype_shape(
          resource_shape.begin(), resource_shape.end());
      auto resource_subtype =
          mlir::RankedTensorType::get(resource_subtype_shape, element_type);
      auto resource_type =
          mlir::TF::ResourceType::get({resource_subtype}, builder.getContext());

      auto tensor_type = mlir_arg.getType().cast<mlir::TensorType>();
      if (tensor_type.hasRank()) {
        mlir_arg.setType(
            mlir::RankedTensorType::get(tensor_type.getShape(), resource_type));
      } else {
        mlir_arg.setType(mlir::UnrankedTensorType::get(resource_type));
      }
      has_resource_args = true;
    }
    if (xla_arg.kind != XlaArgument::kConstant) {
      params.push_back(idx);
      continue;
    }

    TF_ASSIGN_OR_RETURN(auto value_attr,
                        ConvertTensor(xla_arg.constant_value, &builder));
    // TODO(hinsu): Use the actual location of the constant.
    auto constant = builder.create<mlir::TF::ConstOp>(
        mlir::UnknownLoc::get(module_op.getContext()), value_attr);
    mlir_arg.replaceAllUsesWith(constant);
    args_to_erase.push_back(idx);
  }

  if (has_resource_args) {
    llvm::SmallVector<mlir::Type, 4> updated_argument_types;
    updated_argument_types.reserve(main_fn.getNumArguments());
    for (mlir::BlockArgument& arg : main_fn.getArguments())
      updated_argument_types.push_back(arg.getType());

    main_fn.setType(
        mlir::FunctionType::get(main_fn.getContext(), updated_argument_types,
                                main_fn.getFunctionType().getResults()));
  }

  for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx);

  return params;
}
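
// Illustrative effect of RewriteWithArgs (IR sketch, not verbatim output):
// given a `main` whose second argument is a compile-time constant,
//
//   func.func @main(%arg0: tensor<2xf32>, %arg1: tensor<i32>) ...
//
// the constant is materialized as a tf.Const op in the body, all uses of
// %arg1 are replaced with it, and %arg1 is erased from the signature. The
// returned vector then holds the surviving argument indices, here {0}.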

Status CompileGraphSetup(
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
    std::vector<int>* remaining_params,
    llvm::SmallVector<TensorOrResourceShape, 4>& arg_shapes) {
  TF_ASSIGN_OR_RETURN(*remaining_params, RewriteWithArgs(module_op, args));
  arg_shapes.reserve(remaining_params->size());
  for (unsigned idx : *remaining_params) {
    const auto& arg = args[idx];
    TF_ASSIGN_OR_RETURN(TensorShape arg_shape,
                        GetTensorShapeFromXlaArgument(arg));
    arg_shapes.push_back({arg_shape,
                          /*is_resource=*/arg.kind == XlaArgument::kResource});
  }

  mlir::PassManager pm(module_op.getContext());
  applyTensorflowAndCLOptions(pm);
  mlir::TF::StandardPipelineOptions tf_options;
  mlir::TF::CreateTFStandardPipeline(pm, tf_options);

  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("compile_graph_setup_before", module_op);
  if (VLOG_IS_ON(2)) {
    // Print the whole module after each pass, which requires disabling
    // multi-threading as well.
    module_op.getContext()->disableMultithreading();
    pm.enableIRPrinting(std::make_unique<tensorflow::BridgeLoggerConfig>(
        /*print_module_scope=*/true));
  }

  mlir::StatusScopedDiagnosticHandler diag_handler(module_op.getContext());
  if (failed(pm.run(module_op))) return diag_handler.ConsumeStatus();
  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("compile_graph_setup_after", module_op);

  return OkStatus();
}

Status BuildHloFromModule(mlir::ModuleOp module_op, xla::XlaBuilder& builder,
                          llvm::ArrayRef<xla::XlaOp> xla_params,
                          std::vector<xla::XlaOp>& returns,
                          llvm::ArrayRef<XlaArgument> args,
                          llvm::StringRef device_type,
                          llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
                              custom_legalization_passes) {
  std::vector<int> remaining_params;
  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
  TF_RETURN_IF_ERROR(
      CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));
  // Pass down only the remaining (non-constant) xla_params.
  llvm::SmallVector<xla::XlaOp, 2> remaining_xla_params;
  for (auto i : remaining_params) remaining_xla_params.push_back(xla_params[i]);
  return BuildHloFromTf(module_op, builder, remaining_xla_params, returns,
                        arg_shapes, device_type, custom_legalization_passes);
}

Status CompileGraphToXlaHlo(
    mlir::ModuleOp module_op, llvm::ArrayRef<XlaArgument> args,
    llvm::StringRef device_type, bool use_tuple_args, bool analyse_graph,
    bool use_return_tuple,
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  std::vector<int> remaining_params;
  llvm::SmallVector<TensorOrResourceShape, 4> arg_shapes;
  TF_RETURN_IF_ERROR(
      CompileGraphSetup(module_op, args, &remaining_params, arg_shapes));

  auto status = CompileMlirToXlaHlo(
      module_op, arg_shapes, device_type, use_tuple_args, analyse_graph,
      use_return_tuple,
      /*use_resource_updates_for_aliases=*/true, shape_determination_fns,
      compilation_result, custom_legalization_passes);
  compilation_result->input_mapping = remaining_params;
  return status;
}

xla::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> GraphToModule(
    const Graph& graph, llvm::ArrayRef<std::string> control_rets,
    const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
    mlir::MLIRContext* context) {
  mlir::DialectRegistry registry;
  RegisterDialects(registry);
  context->appendDialectRegistry(registry);
  GraphImportConfig config;
  config.graph_as_function = true;
  config.control_outputs = control_rets;
  // Disable shape inference during import, as some TensorFlow ops fail shape
  // inference with dynamic-shaped operands, which in turn causes the import to
  // fail. Shape inference during import is going to be removed; since the
  // shape inference pass already runs early in the pass pipeline, shape
  // inference during import is not necessary.
  config.enable_shape_inference = false;
  return ConvertGraphToMlir(graph, debug_info, flib_def, config, context);
}

Status BuildHloFromGraph(
    const Graph& graph, xla::XlaBuilder& builder,
    mlir::MLIRContext& mlir_context, llvm::ArrayRef<xla::XlaOp> xla_params,
    std::vector<xla::XlaOp>& returns, llvm::ArrayRef<XlaArgument> args,
    llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
    const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  TF_ASSIGN_OR_RETURN(
      mlir::OwningOpRef<mlir::ModuleOp> module,
      GraphToModule(graph, control_rets, flib_def, debug_info, &mlir_context));
  return BuildHloFromModule(module.get(), builder, xla_params, returns, args,
                            device_type, custom_legalization_passes);
}

Status CompileGraphToXlaHlo(
    const Graph& graph, llvm::ArrayRef<XlaArgument> args,
    llvm::ArrayRef<std::string> control_rets, llvm::StringRef device_type,
    bool use_tuple_args, bool analyse_graph,
    const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
    XlaShapeLayoutHelpers::ShapeDeterminationFns shape_determination_fns,
    XlaCompilationResult* compilation_result,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  mlir::MLIRContext context;
  TF_ASSIGN_OR_RETURN(
      mlir::OwningOpRef<mlir::ModuleOp> module,
      GraphToModule(graph, control_rets, flib_def, debug_info, &context));
  return CompileGraphToXlaHlo(
      module.get(), args, device_type, use_tuple_args, analyse_graph,
      /*use_return_tuple=*/true, shape_determination_fns, compilation_result,
      custom_legalization_passes);
}

void RegisterConvertMlirToXlaHloPipelineWithDefaults() {
  static mlir::PassPipelineRegistration<> pipeline(
      "tf-to-hlo-pipeline",
      "Convert TF dialect to HLO dialect (used for compilation in bridge).",
      [](mlir::OpPassManager& pm) {
        tensorflow::CreateConvertMlirToXlaHloPipeline(
            pm, /*device_type=*/"XLA_CPU_JIT", /*prefer_tf2xla=*/false,
            /*custom_legalization_passes=*/{});
      });
}
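
// With this registration in place, the pipeline can be exercised from MLIR
// command-line tools, e.g. (sketch; the exact tool name depends on the build):
//
//   tf-opt --tf-to-hlo-pipeline input.mlir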

}  // namespace tensorflow