xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/create_dtensor_mlir_passes.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_DTENSOR_MLIR_CREATE_DTENSOR_MLIR_PASSES_H_
17 #define TENSORFLOW_DTENSOR_MLIR_CREATE_DTENSOR_MLIR_PASSES_H_
18 
19 #include <memory>
20 
21 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
23 #include "mlir/Pass/Pass.h"  // from @llvm-project
24 #include "mlir/Pass/PassManager.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
26 #include "tensorflow/core/lib/core/status.h"
27 
28 namespace tensorflow {
29 namespace dtensor {
30 
// Returns a caller-owned pass that runs on mlir::func::FuncOp.
// NOTE(review): by its name, presumably groups ops into per-device clusters —
// confirm against the pass definition.
31 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
32 CreateDTensorOpToDeviceClusterPass();
33 
// Returns a caller-owned pass that runs on mlir::func::FuncOp.
// NOTE(review): by its name, presumably merges (coarsens) device-mesh
// clusters — confirm against the pass definition.
34 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
35 CreateDTensorDeviceMeshClusterCoarsening();
36 
// Returns a caller-owned pass that runs on mlir::func::FuncOp.
// NOTE(review): "DCE" presumably stands for dead-code elimination — confirm
// against the pass definition.
37 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> CreateDTensorDCE();
38 
// Returns a caller-owned pass that runs on mlir::func::FuncOp.
// NOTE(review): by its name, presumably reverses a merge of constants that
// were shared across meshes — confirm against the pass definition.
39 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
40 CreateDTensorUndoMergeConstAcrossMesh();
41 
// Returns a caller-owned pass that runs on mlir::func::FuncOp.
// NOTE(review): by its name, presumably performs constant folding — confirm
// against the pass definition.
42 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
43 CreateDTensorConstantFolding();
44 
// Returns a caller-owned pass that runs on mlir::func::FuncOp.
// NOTE(review): by its name, presumably optimizes sum all-reduce ops —
// confirm against the pass definition.
45 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
46 CreateDTensorAllReduceSumOptimization();
47 
// Returns a caller-owned pass that runs on mlir::func::FuncOp.
// NOTE(review): by its name, presumably optimizes all-reduce + scatter
// patterns — confirm against the pass definition.
48 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
49 CreateDTensorAllReduceScatterOptimization();
50 
// Returns a caller-owned pass that runs on mlir::func::FuncOp.
// NOTE(review): by its name, presumably combines multiple all-reduce ops —
// confirm against the pass definition.
51 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
52 CreateDTensorAllReduceCombineOptimization();
53 
// Returns a caller-owned pass that runs on mlir::func::FuncOp.
// NOTE(review): by its name, presumably adjusts the precision used by reduce
// ops — confirm against the pass definition.
54 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
55 CreateDTensorMixedPrecisionReducePass();
56 
// Returns a caller-owned pass that runs on mlir::func::FuncOp.
// NOTE(review): by its name, presumably assigns a default sharding — confirm
// against the pass definition.
57 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
58 CreateDTensorSetDefaultSharding();
59 
// Returns a caller-owned pass that runs on mlir::func::FuncOp.
// NOTE(review): by its name, presumably assigns a mesh to resource-handle
// ops — confirm against the pass definition.
60 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
61 CreateDTensorDesignateResourceHandleMesh();
62 
// Returns a caller-owned pass that runs on mlir::func::FuncOp.
// NOTE(review): by its name, presumably propagates default layouts — confirm
// against the pass definition.
63 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
64 CreateDTensorPropagateDefaultLayout();
65 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably resolves data dependencies between
// clusters — confirm against the pass definition.
66 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
67 CreateDTensorHandleCrossClusterDependencies();
68 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably annotates ops with global shape
// information — confirm against the pass definition.
69 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
70 CreateDTensorAnnotateGlobalShape();
71 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably the second version of layout
// propagation — confirm against the pass definition.
72 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
73 CreateDTensorLayoutPropagationPassV2();
74 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably propagates mesh information across
// the module — confirm against the pass definition.
75 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
76 CreateDTensorMeshPropagationPass();
77 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably performs SPMD expansion — confirm
// against the pass definition.
78 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
79 CreateDTensorSPMDExpansion();
80 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably converts clusters into function
// form — confirm against the pass definition.
81 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
82 CreateDTensorClusterFunctionConversion();
83 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably threads a device id through function
// arguments — confirm against the pass definition.
84 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
85 CreateDTensorPropagateDeviceIdToFunctionArgs();
86 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably TPU-specific integration work —
// confirm against the pass definition.
87 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
88 CreateDTensorTPUIntegration();
89 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably adds a device attribute to resources
// for TPU — confirm against the pass definition.
90 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
91 CreateDTensorTpuAddResourceDeviceAttribute();
92 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably updates TPU metadata — confirm
// against the pass definition.
93 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
94 CreateDTensorUpdateTPUMetadata();
95 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably handles embedding-related ops; how
// it relates to CreateDTensorEmbeddingPassV2 is not visible here — confirm.
96 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
97 CreateDTensorEmbeddingPass();
98 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably a second version of the embedding
// pass — confirm against the pass definition.
99 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
100 CreateDTensorEmbeddingPassV2();
101 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably handles checkpointing for
// embeddings — confirm against the pass definition.
102 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
103 CreateDTensorEmbeddingCheckpointPass();
104 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably renames functions in the module; the
// naming scheme is not visible here — confirm against the pass definition.
105 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
106 CreateFunctionRenamingPass();
107 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably lowers all-reduce ops — confirm
// against the pass definition.
108 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
109 CreateDTensorAllReduceLoweringPass();
110 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably lowers reduce-scatter ops — confirm
// against the pass definition.
111 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
112 CreateDTensorReduceScatterLoweringPass();
113 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably lowers all-gather ops — confirm
// against the pass definition.
114 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
115 CreateDTensorAllGatherLoweringPass();
116 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably lowers all-scatter ops — confirm
// against the pass definition.
117 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
118 CreateDTensorAllScatterLoweringPass();
119 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably merges clusters; the merge criteria
// are not visible here — confirm against the pass definition.
120 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
121 CreateDTensorMergeClustersPass();
122 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably lowers send/recv ops — confirm
// against the pass definition.
123 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
124 CreateDTensorLowerSendRecv();
125 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably relocates compilation-related ops to
// the host — confirm against the pass definition.
126 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
127 CreateDTensorMoveCompilationToHost();
128 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably converts sparse tensors to dense
// tensors — confirm against the pass definition.
129 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
130 CreateDTensorSparseTensorToDenseTensor();
131 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably an expansion step for ops with
// sparse inputs — confirm against the pass definition.
132 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
133 CreateDTensorSparseExpansion();
134 
// Returns a caller-owned pass that runs on mlir::ModuleOp.
// NOTE(review): by its name, presumably infers shapes for RestoreV2 ops —
// confirm against the pass definition.
135 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
136 CreateDTensorInferShapesForRestoreV2Op();
137 
138 // Generate the code for registering passes.
// GEN_PASS_REGISTRATION is defined immediately before the include so that the
// included dtensor_passes.h.inc emits its pass-registration section (the usual
// MLIR declarative-pass convention). The include sits inside
// namespace tensorflow::dtensor, so whatever it emits lands in that namespace.
// NOTE(review): the .inc is presumably tablegen-generated and expected to
// refer to the Create* factories declared above — confirm in the build rules.
139 #define GEN_PASS_REGISTRATION
140 #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"
141 
142 }  // namespace dtensor
143 }  // namespace tensorflow
144 
145 #endif  // TENSORFLOW_DTENSOR_MLIR_CREATE_DTENSOR_MLIR_PASSES_H_
146