1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_DTENSOR_MLIR_CREATE_DTENSOR_MLIR_PASSES_H_ 17 #define TENSORFLOW_DTENSOR_MLIR_CREATE_DTENSOR_MLIR_PASSES_H_ 18 19 #include <memory> 20 21 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project 22 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 23 #include "mlir/Pass/Pass.h" // from @llvm-project 24 #include "mlir/Pass/PassManager.h" // from @llvm-project 25 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" 26 #include "tensorflow/core/lib/core/status.h" 27 28 namespace tensorflow { 29 namespace dtensor { 30 31 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 32 CreateDTensorOpToDeviceClusterPass(); 33 34 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 35 CreateDTensorDeviceMeshClusterCoarsening(); 36 37 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> CreateDTensorDCE(); 38 39 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 40 CreateDTensorUndoMergeConstAcrossMesh(); 41 42 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 43 CreateDTensorConstantFolding(); 44 45 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 46 CreateDTensorAllReduceSumOptimization(); 47 48 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 49 CreateDTensorAllReduceScatterOptimization(); 50 51 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 52 CreateDTensorAllReduceCombineOptimization(); 53 54 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 55 CreateDTensorMixedPrecisionReducePass(); 56 57 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 58 CreateDTensorSetDefaultSharding(); 59 60 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 61 CreateDTensorDesignateResourceHandleMesh(); 62 63 std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> 64 CreateDTensorPropagateDefaultLayout(); 65 66 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 67 CreateDTensorHandleCrossClusterDependencies(); 68 69 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 70 CreateDTensorAnnotateGlobalShape(); 71 72 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 73 CreateDTensorLayoutPropagationPassV2(); 74 75 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 76 CreateDTensorMeshPropagationPass(); 77 78 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 79 CreateDTensorSPMDExpansion(); 80 81 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 82 CreateDTensorClusterFunctionConversion(); 83 84 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 85 CreateDTensorPropagateDeviceIdToFunctionArgs(); 86 87 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 88 CreateDTensorTPUIntegration(); 89 90 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 91 CreateDTensorTpuAddResourceDeviceAttribute(); 92 93 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 94 CreateDTensorUpdateTPUMetadata(); 95 96 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 97 CreateDTensorEmbeddingPass(); 98 99 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 100 CreateDTensorEmbeddingPassV2(); 101 102 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 103 CreateDTensorEmbeddingCheckpointPass(); 104 105 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 106 CreateFunctionRenamingPass(); 107 108 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 109 CreateDTensorAllReduceLoweringPass(); 110 111 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 112 CreateDTensorReduceScatterLoweringPass(); 113 114 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 115 CreateDTensorAllGatherLoweringPass(); 116 117 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 118 CreateDTensorAllScatterLoweringPass(); 119 120 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 121 CreateDTensorMergeClustersPass(); 122 123 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 124 CreateDTensorLowerSendRecv(); 125 126 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 127 CreateDTensorMoveCompilationToHost(); 128 129 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 130 CreateDTensorSparseTensorToDenseTensor(); 131 132 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 133 CreateDTensorSparseExpansion(); 134 135 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 136 CreateDTensorInferShapesForRestoreV2Op(); 137 138 // Generate the code for registering passes. 139 #define GEN_PASS_REGISTRATION 140 #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" 141 142 } // namespace dtensor 143 } // namespace tensorflow 144 145 #endif // TENSORFLOW_DTENSOR_MLIR_CREATE_DTENSOR_MLIR_PASSES_H_ 146