/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"

#include <memory>

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/mlir/create_dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/utils/dtensor_mlir_passes_internal.h"

namespace tensorflow {
namespace dtensor {
namespace {
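// Bridge logger configuration that prints the IR only after passes (never
// before), and only on the coordinator task (client id 0) unless
// LogOnAllTasks() is set. Modules tagged with the `kDoNotLog` attribute are
// never printed.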
class ConditionalPrinter : public BridgeLoggerConfig {
 private:
  bool do_not_print_;

 public:
  explicit ConditionalPrinter(bool print_module_scope = false,
                              bool print_after_only_on_change = true)
      : BridgeLoggerConfig(print_module_scope, print_after_only_on_change) {
    do_not_print_ = !(LogOnAllTasks() || (ClientId() == 0));
  }

  void printBeforeIfEnabled(mlir::Pass *pass, mlir::Operation *operation,
                            PrintCallbackFn print_callback) override {}

  void printAfterIfEnabled(mlir::Pass *pass, mlir::Operation *operation,
                           PrintCallbackFn print_callback) override {
    mlir::ModuleOp module = mlir::dyn_cast<mlir::ModuleOp>(operation);
    if (!module) module = operation->getParentOfType<mlir::ModuleOp>();
    if (module && !module->hasAttr(dtensor::kDoNotLog) && !do_not_print_)
      BridgeLoggerConfig::printAfterIfEnabled(pass, operation, print_callback);
  }
};
}  // namespace

// Adds a logger to the DTensor transformation pass manager. Returns true if
// IR printing was enabled.
bool MaybeEnableLogging(mlir::PassManager *pm) {
  if (VLOG_IS_ON(1)) {
    // Print the whole module after each pass, which requires disabling
    // multi-threading as well.
    pm->getContext()->disableMultithreading();
    pm->enableIRPrinting(std::make_unique<ConditionalPrinter>(
        /*print_module_scope=*/true));
    return true;
  }
  return false;
}

void CreateDTensorMLIRPass(const mlir::TF::StandardPipelineOptions &options,
                           mlir::OpPassManager *pm) {
  // Remove ops that cannot be reached from the sink node.
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::tf_executor::CreateTFExecutorGraphPruningPass());
  // Remove the graph-def executor dialect and represent the IR as a flattened
  // list of TF ops in functions.
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::CreateExecutorDialectToFunctionalConversionPass());

  // This does not guarantee that shapes are inferred for all ops. For ops with
  // dynamic shapes, shape information may still be missing.
  pm->addPass(mlir::TF::CreateTFShapeInferencePass());

  // With the V2 layout propagation algorithm, layouts are expressed as
  // DTensorLayout ops, so the Canonicalizer and Inliner passes will not lose
  // layout information.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorPropagateDefaultLayout());
  pm->addPass(mlir::createSCCPPass());
  pm->addPass(mlir::createCanonicalizerPass());
  pm->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
  pm->addPass(mlir::createInlinerPass());

  // Ensure that all functions have `device_id` as the 0th argument.
  pm->addPass(CreateDTensorPropagateDeviceIdToFunctionArgs());

  // Ensure that every function taking a SparseTensor input has it converted to
  // its three component tensors and that a SparseToDenseOp is emitted for
  // every usage of a SparseTensor.
  pm->addPass(CreateDTensorSparseTensorToDenseTensor());

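  // Embedding-specific passes are added through a hook declared in
  // dtensor_mlir_passes_internal.h; the set of passes it adds depends on the
  // build.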
  AddDTensorEmbeddingPass(pm);

  // After shape inference, there may be unused constant ops added when
  // propagating caller-callee constants. As the DTensor mesh/layout
  // propagation passes assume that there are no unreachable ops, remove these
  // trivially unused ops. Note that the `Canonicalizer` pass in TF includes a
  // similar optimization; however, the canonicalizer pass also rewrites some
  // ops and may remove the `_layout` or `_mesh` attributes from the rewritten
  // TF ops.
  // TODO(hongjunchoi): Remove this pass once the shape inference pass no
  // longer creates unnecessary constant ops.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorDCE());

  // Canonicalization will merge tf.ConstOp from different DTensorLayout
  // annotations, causing problems during mesh propagation. Undo the merge
  // before creating clusters.
  pm->addNestedPass<mlir::func::FuncOp>(
      CreateDTensorUndoMergeConstAcrossMesh());

  // Propagate the mesh cluster config and cluster ops by mesh cluster so that
  // SPMD expansion can be isolated to a single device mesh.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorOpToDeviceClusterPass());
  pm->addPass(CreateDTensorMeshPropagationPass());

  {
    mlir::OpPassManager &func_pm = pm->nest<mlir::func::FuncOp>();
    func_pm.addPass(CreateDTensorDeviceMeshClusterCoarsening());
    // Set an empty layout on clusters wrapping `tf.VarHandleOp`. VarHandleOp
    // always runs on the default device where the client program executes.
    func_pm.addPass(CreateDTensorDesignateResourceHandleMesh());
  }

  // Validate that all cross-mesh data transfers are expressed via
  // DTensorLayout operations and lower them to send/recv ops.
  pm->addPass(CreateDTensorHandleCrossClusterDependencies());

  // Mark all ops and functions with a global shape attribute to preserve
  // global shape information, as it is needed during layout propagation and
  // SPMD expansion.
  pm->addPass(CreateDTensorAnnotateGlobalShape());

  // Merge clusters assigned to the same mesh in preparation for layout
  // propagation.
  pm->addPass(CreateDTensorMergeClustersPass());

  AddDTensorEmbeddingPassV2(pm);

  // For DTensor Checkpoint V2, the outputs of tf.RestoreV2 ops
  // do not have shape information. We can infer the shapes of these
  // outputs from the tf.AssignVariableOps that consume these outputs.
  // This pass fills in all missing shapes caused by tf.RestoreV2 ops.
  if (DTensorCheckpointV2Enabled()) {
    pm->addPass(CreateDTensorInferShapesForRestoreV2Op());
  }

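  // Propagate layouts to all ops in the graph.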
  pm->addPass(CreateDTensorLayoutPropagationPassV2());

  // Expand the graph to SPMD form, given that layouts are annotated on all
  // ops. Remove all DTensorLayout ops after the expansion is done.
  pm->addPass(CreateDTensorSPMDExpansion());

  // Insert functions to save or load embeddings when using the TPU device.
  AddDTensorEmbeddingCheckpointPass(pm);

  // Expand all ops that consume SparseTensors to possibly new ops.
  // Remove any unused SparseToDense, Layout, and Const ops after
  // the expansion is done.
  //
  // Note that this pass assumes that SparseTensor operands are represented
  // as operands coming from the output of a SparseToDenseOp. Thus, this pass
  // must happen after the SparseTensorToDenseTensor pass and after
  // the SPMD expansion pass.
  pm->addPass(CreateDTensorSparseExpansion());

  // Do a round of CSE: this helps reduce the number of consts in the graph now
  // that SPMD expansion is done. We replicated all Consts (so that each const
  // had only one use) as part of layout propagation.
  pm->addPass(mlir::createCSEPass());

  // Lower the AllGather collectives. This has to happen before the all-reduce
  // optimizations because lowering AllGather may emit an AllReduce.
  pm->addPass(CreateDTensorAllGatherLoweringPass());

  // Fuse AllReduce and AllScatter into ReduceScatter.
  if (!DoNotFuseReduceScatter()) {
    pm->addNestedPass<mlir::func::FuncOp>(
        CreateDTensorAllReduceScatterOptimization());
  }

  // Reorder DTensorAllReduce + Add into Add + DTensorAllReduce to minimize the
  // number of all-reduce operations.
  pm->addNestedPass<mlir::func::FuncOp>(
      CreateDTensorAllReduceSumOptimization());

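  // Combine compatible DTensorAllReduce ops into fewer, larger all-reduce
  // operations where possible.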
  AddDTensorAllReduceCombineOptimization(pm);

  // DTensorReduceScatter lowering should come before the DTensorAllReduce
  // and DTensorAllScatter lowerings since, for some devices,
  // DTensorReduceScatter will be decomposed into a DTensorAllReduce +
  // DTensorAllScatter.
  pm->addPass(CreateDTensorReduceScatterLoweringPass());

  // For large enough reduction groups in reduction ops, upcast the input
  // tensors to a higher-precision type (e.g. bfloat16 -> float32).
  if (EnableMixedPrecisionReduce()) {
    pm->addNestedPass<mlir::func::FuncOp>(
        CreateDTensorMixedPrecisionReducePass());
  }

  // Lower device-agnostic logical AllReduce ops into device-specific physical
  // AllReduce ops.
  //
  // First, find DTensor collective ops such as DTensorAllReduce, which are
  // generated by SPMD expansion. Lower them into device-specific forms. For
  // most devices, there is a one-to-one mapping: DTensorAllReduce becomes
  // CollectiveReduce on CPUs/GPUs and XlaAllReduce on TPU pods.
  // Optionally, for special topologies, DTensorAllReduce
  // could become a chain of collectives running on different devices:
  // XlaAllReduce on each donut followed by CollectiveReduce on the hosts.
  // Those collective ops running on hosts will have their _mesh attribute set
  // to empty by this pass. The other ops continue to have no _mesh attributes,
  // which means they run on the cluster mesh.
  pm->addPass(CreateDTensorAllReduceLoweringPass());

  pm->addPass(CreateDTensorAllScatterLoweringPass());

  // Group together multiple device clusters assigned to the same mesh. Repeat
  // this for every mesh to support multi-mesh. Collective lowering may have
  // created multiple CPU mesh clusters for executing collective operations on
  // CPUs. As such, we merge the newly created CPU clusters after collective
  // lowering, especially for special topologies.
  pm->addPass(CreateDTensorMergeClustersPass());
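  // Lower the DTensor send/recv ops introduced for cross-mesh transfers.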
  pm->addPass(CreateDTensorLowerSendRecv());

  // Convert tf_device.cluster ops into function call ops.
  pm->addPass(mlir::TFDevice::CreateClusterOutliningPass());
  pm->addPass(CreateDTensorClusterFunctionConversion());

  // During layout propagation, we clone all constants with multiple consumers
  // for easier analysis.
  // This may create multiple identical constant ops. Apply constant folding on
  // the duplicated constant operations to reduce graph size.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorConstantFolding());
  // DTensor SPMD lowering passes may have created auxiliary operations that
  // are no longer used. Add an additional DCE pass to remove unused,
  // non-side-effecting ops.
  pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorDCE());

  // DTensor SPMD expansion may have introduced multiple control flows and
  // duplicate ops that calculate the device ordinal. Re-run SCCP and merge
  // control flows where possible.
  pm->addNestedPass<mlir::func::FuncOp>(mlir::createSCCPPass());
  pm->addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
  pm->addPass(mlir::TFDevice::CreateMergeControlFlowPass());

  // TF2XLA Integration
  {
    // Make sure clusters that run on TPUs have the correct metadata ops and
    // attributes attached so they are compatible with the later TPU-specific
    // optimization passes.
    pm->addPass(CreateDTensorTPUIntegration());

    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TFDevice::CreateDecomposeResourceOpsPass());
    // Sink constant ops into the cluster region, as DecomposeResourceOpsPass()
    // could have lifted constants out due to folding.
    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TFDevice::CreateClusterConstantSinkingPass());

    // Run another shape inference pass (and a following DCE pass) because
    // resource decomposition might have created new partial types.
    pm->addPass(mlir::TF::CreateTFShapeInferencePass());
    pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorDCE());
    pm->addPass(mlir::TFDevice::CreateResourceOpLiftingPass());
    pm->addPass(mlir::TFDevice::CreateClusterOutliningPass());

    // Rename functions with unique names to avoid collisions in the function
    // library.
    pm->addPass(CreateFunctionRenamingPass());

    // As DTensor SPMD expansion handles sharded inputs for model
    // parallelism, we set input/output sharding to maximal sharding
    // for inputs/outputs of the TPU computation.
    pm->addNestedPass<mlir::func::FuncOp>(CreateDTensorSetDefaultSharding());

    // Mark TPU cluster input-output pairs that read and write the same
    // resource variable as aliases.
    pm->addPass(mlir::TFDevice::CreateMarkInputOutputAliasesPass());

    // Convert compilation and replication attributes to the unified attributes
    // expected by TPURewritePass.
    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TFTPU::CreateCanonicalizeCompileAndReplicateAttributesPass());
    // Create TPU Compile and TPU Execute ops for each TPU device.
    pm->addPass(mlir::TFTPU::CreateTPURewritePass());
    // Convert the unified compilation and replication attributes back to
    // legacy attributes for subsequent passes.
    pm->addNestedPass<mlir::func::FuncOp>(
        mlir::TFTPU::CreateConvertToLegacyCompileAndReplicateAttributesPass());

    // Add placeholder device attributes to resource arguments of the TPU
    // computation. This ensures that the following
    // CreateTPUMergeVariablesWithExecutePass correctly merges resource
    // operations with the TPUExecute op.
    pm->addPass(CreateDTensorTpuAddResourceDeviceAttribute());
    // Translate the TPUExecute op to a TPUExecuteAndUpdateVariable op to
    // enable buffer aliasing.
    pm->addPass(mlir::TFTPU::CreateTPUMergeVariablesWithExecutePass());

    pm->addPass(CreateDTensorUpdateTPUMetadata());
    // If send/recv ops exist between TPU and CPU, then the TPU compilation
    // program key is used as an input to the recv op in the host computation
    // as well as to the TPUExecute op in the device computation. As such, move
    // the TPUCompile logic to the host computation and transfer the program
    // key using send/recv operations.
    pm->addPass(CreateDTensorMoveCompilationToHost());
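    // Remove symbols (e.g. outlined functions) that are no longer referenced
    // after the transformations above.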
    pm->addPass(mlir::createSymbolDCEPass());
  }

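  // Convert region-based control flow back to functional control flow ops
  // (e.g. tf.If/tf.While) before lowering to the executor dialect.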
  pm->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());

  // Convert the graph into the graph executor dialect so that the transformed
  // graph can be exported back to a GraphDef.
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::CreateFunctionalToExecutorDialectConversionPass());
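  // Break up multi-op executor islands into per-op islands, as expected by the
  // exporter.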
  pm->addPass(mlir::CreateBreakUpIslandsPass());
  pm->addNestedPass<mlir::func::FuncOp>(
      mlir::TFDevice::CreateLaunchToDeviceAttributePass());
  // Add an additional BreakUpIslandsPass, as the LaunchToDeviceAttribute pass
  // may have created additional islands.
  pm->addPass(mlir::CreateBreakUpIslandsPass());
}

}  // namespace dtensor
}  // namespace tensorflow