/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"

#include <memory>

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/mlir/create_dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/utils/dtensor_mlir_passes_internal.h"

namespace tensorflow {
namespace dtensor {
namespace {
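// BridgeLoggerConfig that dumps the module only after each pass, only on the
// primary client (or on every task when LogOnAllTasks() is set), and skips
// modules annotated with the kDoNotLog attribute.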
class ConditionalPrinter : public BridgeLoggerConfig {
 private:
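  // True when this task should not dump IR, i.e. it is not client 0 and
  // LogOnAllTasks() is disabled.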
  bool do_not_print_;

 public:
  explicit ConditionalPrinter(bool print_module_scope = false,
                              bool print_after_only_on_change = true)
      : BridgeLoggerConfig(print_module_scope, print_after_only_on_change) {
    do_not_print_ = !(LogOnAllTasks() || (ClientId() == 0));
  }

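  // IR is never printed before a pass; only the post-pass IR below is logged.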
  void printBeforeIfEnabled(mlir::Pass *pass, mlir::Operation *operation,
                            PrintCallbackFn print_callback) override {}

  void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *operation,
                           PrintCallbackFn print_callback) override {
    mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(operation);
    if (!module) module = operation->getParentOfType<mlir::ModuleOp>();
    if (module && !module->hasAttr(dtensor::kDoNotLog) && !do_not_print_)
      BridgeLoggerConfig::printAfterIfEnabled(pass, operation, print_callback);
  }
};
}  // namespace

// Adds a logger to the DTensor transformation pass manager.
bool MaybeEnableLogging(mlir::PassManager *pm) {
  if (VLOG_IS_ON(1)) {
    // Print the whole module after each pass, which requires disabling
    // multi-threading as well.
    pm->getContext()->disableMultithreading();
    pm->enableIRPrinting(std::make_unique<ConditionalPrinter>(
        /*print_module_scope=*/true));
    return true;
  }
  return false;
}

void CreateDTensorMLIRPass(const mlir::TF::StandardPipelineOptions &options,
                           mlir::OpPassManager *pm) {
  // Remove ops that cannot be reached from the sink node.
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::tf_executor::CreateTFExecutorGraphPruningPass());
  // Remove graph-def executor dialect and represent IR as a flattened list of
  // TF ops in functions.
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::CreateExecutorDialectToFunctionalConversionPass());

  // This does not guarantee that shapes are inferred for all ops. For ops with
  // dynamic shapes, shape information may still be missing.
  pm->addPass(mlir::TF::CreateTFShapeInferencePass());

  // With the V2 layout propagation algorithm, layouts are expressed as
  // DTensorLayout ops, so the Canonicalizer and Inliner passes will not lose
  // layout information.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorPropagateDefaultLayout());
  pm->addPass(mlir::createSCCPPass());
  pm->addPass(mlir::createCanonicalizerPass());
  pm->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
  pm->addPass(mlir::createInlinerPass());

  // Ensure that all functions have `device_id` as the 0th argument.
  pm->addPass(CreateDTensorPropagateDeviceIdToFunctionArgs());

  // Ensure that every function with a SparseTensor input is converted to its
  // three component tensors and that a SparseToDenseOp is emitted for every
  // usage of a SparseTensor.
  pm->addPass(CreateDTensorSparseTensorToDenseTensor());

  AddDTensorEmbeddingPass(pm);

  // After shape inference, unused constant ops may have been added when
  // propagating caller-callee constants. As the DTensor mesh/layout
  // propagation passes assume that there are no unreachable ops, remove
  // trivial unused ops here. Note that the `Canonicalizer` pass in TF includes
  // a similar optimization. However, the canonicalizer pass also rewrites some
  // ops and may remove `_layout` or `_mesh` attributes from the rewritten TF
  // ops.
  // TODO(hongjunchoi): Remove this pass once the shape inference pass no
  // longer creates unnecessary constant ops.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorDCE());

  // Canonicalization will merge tf.ConstOp from different DTensorLayout
  // annotations, causing problems during mesh propagation. Undo the merge
  // before creating clusters.
  pm->addNestedPass<mlir::func::FuncOp>(
      CreateDTensorUndoMergeConstAcrossMesh());

  // Propagate mesh cluster config and cluster ops by mesh cluster so that
  // SPMD expansion can be isolated to a single device mesh.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorOpToDeviceClusterPass());
  pm->addPass(CreateDTensorMeshPropagationPass());

  {
    mlir::OpPassManager &func_pm = pm->nest<mlir::func::FuncOp>();
    func_pm.addPass(CreateDTensorDeviceMeshClusterCoarsening());
    // Set an empty layout on clusters wrapping `tf.VarHandleOp`. The VarHandle
    // op always runs on the default device where the client program executes.
    func_pm.addPass(CreateDTensorDesignateResourceHandleMesh());
  }

  // Validate that all cross-mesh data transfers are expressed via
  // DTensorLayout operations and lower them to send/recvs.
  pm->addPass(CreateDTensorHandleCrossClusterDependencies());

  // Mark all ops and functions with a global shape attribute to preserve
  // global shape information, as it is needed during Layout Propagation and
  // SPMD expansion.
  pm->addPass(CreateDTensorAnnotateGlobalShape());

  // Propagate layout to all ops in the graph.
  pm->addPass(CreateDTensorMergeClustersPass());

  AddDTensorEmbeddingPassV2(pm);

  // For DTensor Checkpoint V2, the outputs of tf.RestoreV2 ops
  // do not have shape information. We can infer the shapes of these
  // outputs from the tf.AssignVariableOps that consume these outputs.
  // This pass fills in all missing shapes caused by tf.RestoreV2 ops.
  if (DTensorCheckpointV2Enabled()) {
    pm->addPass(CreateDTensorInferShapesForRestoreV2Op());
  }

  pm->addPass(CreateDTensorLayoutPropagationPassV2());

  // Expand the graph to SPMD form, given that layouts are annotated on all
  // ops. Remove all DTensorLayout ops after the expansion is done.
  pm->addPass(CreateDTensorSPMDExpansion());

  // Insert functions to save or load embeddings when using a TPU device.
  AddDTensorEmbeddingCheckpointPass(pm);

  // Expand all ops that consume SparseTensors to possibly new ops.
  // Remove any unused SparseToDense, Layout, and Const Ops after
  // the expansion is done.
  //
  // Note that this pass assumes that SparseTensor operands are represented
  // as operands coming from the output of a SparseToDenseOp. Thus, this pass
  // must happen after the SparseTensorToDenseTensor pass and after
  // the SPMD Expansion pass.
  pm->addPass(CreateDTensorSparseExpansion());

  // Do a round of CSE: this helps reduce the number of consts in the graph now
  // that SPMD expansion is done. We had replicated all Consts (so that each
  // const only had one usage) as part of layout propagation.
  pm->addPass(mlir::createCSEPass());

  // Lower the AllGather collectives. This has to happen before the all-reduce
  // optimizations, since AllGather may emit an AllReduce.
  pm->addPass(CreateDTensorAllGatherLoweringPass());

  // Fuse AllReduce and AllScatter into ReduceScatter.
  if (!DoNotFuseReduceScatter()) {
    pm->addNestedPass<mlir::func::FuncOp>(
        CreateDTensorAllReduceScatterOptimization());
  }

  // Change the order of DTensorAllReduce + Add to Add + DTensorAllReduce to
  // minimize the number of all-reduce operations.
  pm->addNestedPass<mlir::func::FuncOp>(
      CreateDTensorAllReduceSumOptimization());

  AddDTensorAllReduceCombineOptimization(pm);

  // DTensorReduceScatter lowering should come before the DTensorAllReduce
  // and DTensorAllScatter lowerings since, for some devices,
  // DTensorReduceScatter will be decomposed into a DTensorAllReduce +
  // DTensorScatter.
  pm->addPass(CreateDTensorReduceScatterLoweringPass());

  // For large enough reduction groups in reduction ops, upcast the input
  // tensors to a higher precision type (e.g. bfloat16 -> float32).
  if (EnableMixedPrecisionReduce()) {
    pm->addNestedPass<mlir::func::FuncOp>(
        CreateDTensorMixedPrecisionReducePass());
  }

  // Lower device-agnostic logical AllReduce ops into device-specific physical
  // AllReduce ops.
  //
  // First, find DTensor collective ops such as DTensorAllReduce, which are
  // generated by SPMD expansion. Lower them into device-specific forms. For
  // most devices, there is a one-to-one mapping: DTensorAllReduce becomes
  // CollectiveReduce on CPUs/GPUs and XlaAllReduce on TPU pods.
  // Optionally, for special topologies, DTensorAllReduce
  // could become a chain of collectives running on different devices:
  // XlaAllReduce on each donut followed by CollectiveReduce on the hosts. Those
  // collective ops running on hosts will have their _mesh attribute set to
  // empty by this pass. The other ops continue to have no _mesh attributes,
  // which means they run on the cluster mesh.
  pm->addPass(CreateDTensorAllReduceLoweringPass());

  pm->addPass(CreateDTensorAllScatterLoweringPass());

  // Group together multiple device clusters assigned to the same mesh. Repeat
  // this for every mesh to support multi-mesh. Collective lowering may have
  // created multiple CPU mesh clusters for executing collective operations on
  // CPUs. As such, we merge the newly created CPU clusters after collective
  // lowering, especially for special topologies.
  pm->addPass(CreateDTensorMergeClustersPass());
  pm->addPass(CreateDTensorLowerSendRecv());

  // Convert tf_device.cluster into a function call op.
  pm->addPass(mlir::TFDevice::CreateClusterOutliningPass());
  pm->addPass(CreateDTensorClusterFunctionConversion());

  // During layout propagation, we clone all constants with multiple consumers
  // for easier analysis. This may create multiple identical constant ops.
  // Apply constant folding on the duplicated constant operations to reduce
  // graph size.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorConstantFolding());
  // DTensor SPMD lowering passes may have created auxiliary operations that
  // are no longer used. Add an additional DCE pass to remove unused
  // non-side-effecting ops.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorDCE());

  // DTensor SPMD Expansion may have caused multiple control flows and
  // duplicate ops that calculate the device ordinal. Re-run SCCP and merge
  // control flows if possible.
  pm->addNestedPass<mlir::func::FuncOp>(mlir::createSCCPPass());
  pm->addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  pm->addPass(mlir::TFDevice::CreateMergeControlFlowPass());

  // TF2XLA Integration
  {
    // Make sure clusters that run on TPUs have the correct metadata ops and
    // attributes attached so they are compatible with later TPU-specific
    // optimization passes.
    pm->addPass(CreateDTensorTPUIntegration());

    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TFDevice::CreateDecomposeResourceOpsPass());
    // Sink constant ops into the cluster region, as DecomposeResourceOpsPass()
    // could lift constants out due to folding.
    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TFDevice::CreateClusterConstantSinkingPass());

    // Run another shape inference pass (and the following DCE pass) because
    // resource decomposition might have created new partial types.
    pm->addPass(mlir::TF::CreateTFShapeInferencePass());
    pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorDCE());
    pm->addPass(mlir::TFDevice::CreateResourceOpLiftingPass());
    pm->addPass(mlir::TFDevice::CreateClusterOutliningPass());

    // Rename functions with unique names to avoid collisions in the function
    // library.
    pm->addPass(CreateFunctionRenamingPass());


    // As DTensor SPMD expansion handles sharded inputs for model
    // parallelism, we set input/output sharding to maximal sharding
    // for inputs/outputs of the TPU computation.
    pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorSetDefaultSharding());

    // Mark TPU cluster input-output pairs that read and write the same
    // resource variable as aliases.
    pm->addPass(mlir::TFDevice::CreateMarkInputOutputAliasesPass());

    // Convert compilation and replication attributes to the unified attributes
    // expected by TPURewritePass.
    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TFTPU::CreateCanonicalizeCompileAndReplicateAttributesPass());
    // Create TPU Compile and TPU Execute ops for each TPU device.
    pm->addPass(mlir::TFTPU::CreateTPURewritePass());
    // Convert unified compilation and replication attributes back to legacy
    // attributes for subsequent passes.
    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TFTPU::CreateConvertToLegacyCompileAndReplicateAttributesPass());

    // Add placeholder device attributes to resource arguments of the TPU
    // computation. This ensures that the following
    // CreateTPUMergeVariablesWithExecutePass correctly merges resource
    // operations with the TPUExecute op.
    pm->addPass(CreateDTensorTpuAddResourceDeviceAttribute());
    // Translate the TPUExecute op to a TPUExecuteAndUpdateVariable op to
    // enable buffer aliasing.
    pm->addPass(mlir::TFTPU::CreateTPUMergeVariablesWithExecutePass());

    pm->addPass(CreateDTensorUpdateTPUMetadata());
    // If send/recv exists between TPU and CPU, then the TPU compilation
    // program key is used as an input for the recv op in the host computation
    // as well as for the TPUExecute op in the device computation. As such,
    // move the TPUCompile logic to the host computation and transfer the
    // program key using send/recv operations.
    pm->addPass(CreateDTensorMoveCompilationToHost());
    pm->addPass(mlir::createSymbolDCEPass());
  }

  pm->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());

  // Convert the graph into the graph executor dialect so that the transformed
  // graph can be exported back to a GraphDef.
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::CreateFunctionalToExecutorDialectConversionPass());
  pm->addPass(mlir::CreateBreakUpIslandsPass());
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::TFDevice::CreateLaunchToDeviceAttributePass());
  // Add an additional BreakUpIslands pass, as the LaunchToDeviceAttribute pass
  // may have created additional islands.
  pm->addPass(mlir::CreateBreakUpIslandsPass());
}

}  // namespace dtensor
}  // namespace tensorflow