/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/cc/dtensor_graph_to_mlir_pass.h"

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"

namespace tensorflow {

DTensorMlirPassRunner::DTensorMlirPassRunner()
    : pass_manager_(&context_), logging_enabled_(false) {
  logging_enabled_ = dtensor::MaybeEnableLogging(&pass_manager_);
  if (logging_enabled_) pass_manager_.getContext()->disableMultithreading();
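  // MLIR's IR-printing instrumentation requires a single-threaded context, so
  // multithreading stays off while pass logging is enabled; RunOnGraph
  // re-enables it once a logged run completes.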

  // TODO(hinsu, hongjunchoi): Figure out a better place to explicitly enable
  // the MLIR bridge.
  // Explicitly enable the MLIR bridge, as DTensor introduces ops such as
  // XlaAllReduce that are only supported in MLIR.
  GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge =
      ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;

  // Creates a pipeline that includes all DTensor-related passes.
  mlir::TF::StandardPipelineOptions pipeline_options;
  dtensor::CreateDTensorMLIRPass(pipeline_options, &pass_manager_);
}

Status DTensorMlirPassRunner::RunOnGraph(
    const DeviceSet& device_set, bool is_func,
    FunctionLibraryDefinition* flib_def, std::unique_ptr<Graph>* graph,
    absl::flat_hash_set<Node*>& control_ret_nodes, Fprint128 cache_key) {
  Graph* input_graph = graph->get();
  GraphDebugInfo debug_info;
  GraphImportConfig import_config;
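  // Import the graph as a function body: feeds and fetches become the entry
  // function's arguments and results rather than placeholder/retval nodes.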
  import_config.graph_as_function = true;
  // DTensor currently relies on importing with shape inference to work
  // properly. Set enable_shape_inference explicitly so we are not affected by
  // a potential flip of its default.
  import_config.enable_shape_inference = true;
  // Graph pruning removes an op (even a side-effecting one) if it is not
  // reachable from a fetch/result or a target/control ret. With how the entry
  // function/Graph is created, this can happen when the op has no data
  // results. To make sure the op does not get pruned away, it is declared as a
  // target/control ret.
  import_config.control_outputs = {"eager_operation"};

  // Import the GraphDef into TF MLIR.
  stream_executor::port::StatusOr<mlir::OwningOpRef<mlir::ModuleOp>>
      module_ref = ConvertGraphToMlir(*input_graph, debug_info, *flib_def,
                                      import_config, &context_);
  if (!module_ref.ok())
    return errors::InvalidArgument(
        absl::StrCat(
            "Cannot convert the graph to MLIR; errors from MLIR converter: ",
            module_ref.status().error_message())
            .c_str());

  mlir::ModuleOp module = module_ref.ValueOrDie().get();

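  // Record the available devices on the module so device-aware passes can
  // query them.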
  AddDevicesToOp(module, &device_set);

  // Depending on the flags, tag the module so that it is skipped by logging.
  if (!is_func && !dtensor::LogOpByOp())
    module->setAttr(dtensor::kDoNotLog, mlir::UnitAttr::get(&context_));

  // Set the cache key for the module as an attribute. This attribute will be
  // used to rename all private functions in the module (by appending the
  // cache key) so they have unique names.
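  // For example, a private function @fn might become something like
  // @fn_123_456; the name and values here are illustrative only.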
  module->setAttr(
      dtensor::kCacheKey,
      mlir::StringAttr::get(&context_, absl::StrCat("_", cache_key.low64, "_",
                                                    cache_key.high64)));

  // Executes and collects results from the passes.
  mlir::StatusScopedDiagnosticHandler diag_handler(&context_);

  if (logging_enabled_ && !module->hasAttr(dtensor::kDoNotLog))
    pass_manager_.getContext()->disableMultithreading();
  mlir::LogicalResult result = pass_manager_.run(module);
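  // The LogicalResult is intentionally discarded; pass failures are surfaced
  // through the scoped diagnostic handler consumed below.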
  (void)result;
  TF_RETURN_IF_ERROR(diag_handler.ConsumeStatus());

  if (logging_enabled_) pass_manager_.getContext()->enableMultithreading();

  // Convert the MLIR module back to a Graph for execution.
  GraphExportConfig export_config;
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      ConvertMlirToGraph(module, export_config, graph, flib_def,
                         &control_ret_nodes),
      "Error converting MLIR module back to graph");
  Graph* output_graph = graph->get();
  VLOG(4) << DumpGraphToFile("dtensor_mlir_pass_after", *output_graph,
                             flib_def);
  return OkStatus();
}

}  // namespace tensorflow