xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/sparse_expansion.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <optional>
18 
19 #include "llvm/ADT/SmallVector.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
22 #include "mlir/Pass/Pass.h"  // from @llvm-project
23 #include "mlir/Pass/PassManager.h"  // from @llvm-project
24 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
25 #include "mlir/Transforms/Passes.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
27 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
28 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
29 #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
30 #include "tensorflow/dtensor/mlir/op_utils.h"
31 #include "tensorflow/dtensor/mlir/sparse_expander.h"
32 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
33 
34 namespace tensorflow {
35 namespace dtensor {
36 namespace {
37 
38 constexpr char kMainFunctionName[] = "main";
39 
40 // Expand every op that consumes SparseTensor operands in topological order.
ConductSparseExpansion(mlir::ModuleOp module)41 mlir::LogicalResult ConductSparseExpansion(mlir::ModuleOp module) {
42   auto main_func = module.lookupSymbol<mlir::func::FuncOp>(kMainFunctionName);
43   if (!main_func)
44     return module.emitOpError(
45         "could not find `main` function in module for SPMD expansion.");
46 
47   TopologicalIterator iterator(main_func);
48   while (iterator.hasNext()) {
49     mlir::Operation* op = iterator.next();
50 
51     mlir::Operation* expanded_op = nullptr;
52     auto status = RunSparseExpansion(op, &expanded_op);
53     if (!status.ok() || expanded_op == nullptr) {
54       // Sometimes op may been erased and expanded_op set.
55       // In this case we should emit the error on the expanded op.
56       mlir::Operation* emit_op = op;
57       if (expanded_op != nullptr) emit_op = expanded_op;
58       return emit_op->emitError(WithContext(status, __FILE__, __LINE__,
59                                             "While computing Sparse expansion")
60                                     .error_message());
61     }
62   }
63   return mlir::success();
64 }
65 
66 // After Sparse Expansion pass, there may be unused SparseToDenseOps due to
67 // expanded ops possibly taking the operands of the SparseToDenseOps instead
68 // of the output of the SparseToDenseOps. So remove unused SparseToDenseOps
69 // and its corresponding dependent ops like DTensorLayout and Const ops.
RemoveUnusedSparseToDenseOps(mlir::ModuleOp module)70 void RemoveUnusedSparseToDenseOps(mlir::ModuleOp module) {
71   llvm::SmallVector<mlir::TF::SparseToDenseOp, 4> sparse_ops_to_erase;
72   llvm::SmallVector<mlir::TF::DTensorLayout, 4> layout_ops_to_erase;
73 
74   module.walk([&](mlir::TF::SparseToDenseOp op) {
75     // Delete this op if it either has no consuming ops or the only consuming
76     // op is a DTensorLayout op that also has no consuming ops.
77     if (op->use_empty()) {
78       sparse_ops_to_erase.emplace_back(op);
79     } else if (op->hasOneUse()) {
80       if (auto layout_op = mlir::dyn_cast<mlir::TF::DTensorLayout>(
81               op->getOpResult(0).getUses().begin().getUser())) {
82         if (layout_op.use_empty()) {
83           layout_ops_to_erase.emplace_back(layout_op);
84           sparse_ops_to_erase.emplace_back(op);
85         }
86       }
87     }
88   });
89 
90   // First delete Layout ops and then delete SparseToDense ops.
91   for (auto op : layout_ops_to_erase) op.erase();
92   for (auto op : sparse_ops_to_erase) {
93     // Also delete the corresponding Const ops that are no longer used
94     // attached to the SparseToDense ops.
95     auto const_op = op.getOperand(3).getDefiningOp();
96     op.erase();
97     if (const_op->use_empty()) const_op->erase();
98   }
99 }
100 
101 struct DTensorSparseExpansion
102     : public DTensorSparseExpansionBase<DTensorSparseExpansion> {
runOnOperationtensorflow::dtensor::__anon7d671fb20111::DTensorSparseExpansion103   void runOnOperation() override {
104     auto module = getOperation();
105     if (failed(ConductSparseExpansion(module))) return signalPassFailure();
106 
107     // After Sparse Expansion, we may no longer use any SparseToDenseOp outputs,
108     // so remove them if they are not used.
109     RemoveUnusedSparseToDenseOps(module);
110   };
111 };
112 
113 }  // namespace
114 
115 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorSparseExpansion()116 CreateDTensorSparseExpansion() {
117   return std::make_unique<DTensorSparseExpansion>();
118 }
119 
120 }  // namespace dtensor
121 }  // namespace tensorflow
122