/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/dtensor_send_recv.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Sets the result type of the operation that produces `value` to `type`.
//
// Recursively propagates `type` backward through the producing ops until
// reaching the tf.RestoreV2Op that produced the unknown shape associated
// with `value`.
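//
// A minimal illustrative sketch (the concrete shape tensor<4x8xf32> below is
// hypothetical):
//   %v = "tf.RestoreV2"(...)  : ... -> tensor<*xf32>
//   %i = "tf.Identity"(%v)    : ... -> tensor<*xf32>
// Calling this function with `value` = %i and `type` = tensor<4x8xf32>
// rewrites the Identity result to tensor<4x8xf32> and then recurses to set
// the RestoreV2 result to the same type.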
mlir::LogicalResult BackwardShapeInferenceToRestoreOp(mlir::ModuleOp module,
                                                      mlir::OpBuilder* builder,
                                                      mlir::Value value,
                                                      mlir::Type type) {
  mlir::Operation* op = value.getDefiningOp();
  if (op == nullptr) return mlir::success();
  if (!llvm::isa<mlir::TF::IdentityOp, mlir::TF::DTensorRecv,
                 mlir::TF::RestoreV2Op>(op)) {
    return op->emitOpError(llvm::formatv(
        "Expected an Identity, DTensorRecv, or RestoreV2 op, but got: {0}",
        op->getName().getStringRef()));
  }

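  // Replacement ops created below should take the place of `op`, so anchor
  // the builder right after it.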
  builder->setInsertionPointAfter(op);

  // Base case: if we have reached the RestoreV2Op, we are at the root of the
  // unknown shape. Set the type of the result corresponding to `value` to
  // `type`.
  if (auto restore_op = llvm::dyn_cast_or_null<mlir::TF::RestoreV2Op>(op)) {
    // Mutating a result type in place is usually a dangerous operation, but
    // since we are propagating shapes backward and setting them consistently,
    // we can modify the value itself here instead of creating a new
    // RestoreV2 op.
    //
    // Creating a new RestoreV2 op and replacing all uses would make this
    // algorithm run in O(N^2) where N = number of outputs of RestoreV2.
    //
    // Using setType(type) modifies in place and makes this algorithm run in
    // O(N).
    value.setType(type);
  } else if (auto identity_op =
                 llvm::dyn_cast_or_null<mlir::TF::IdentityOp>(op)) {
    auto new_identity_op = builder->create<mlir::TF::IdentityOp>(
        identity_op.getLoc(), type, identity_op.input());
    identity_op.output().replaceAllUsesWith(new_identity_op.output());
    identity_op.erase();

    // Recursively run shape inference on the input of the new identity op.
    return BackwardShapeInferenceToRestoreOp(module, builder,
                                             new_identity_op.input(), type);
  } else if (auto recv_op = llvm::dyn_cast_or_null<mlir::TF::DTensorRecv>(op)) {
    // If we have a DTensorRecv, then there is a cross-mesh transfer and the
    // RestoreV2Op we want to fix lives on the mesh of the corresponding
    // DTensorSend. Set the shape of this DTensorRecv first, then continue at
    // the corresponding DTensorSend.
    auto new_recv_op = builder->create<mlir::TF::DTensorRecv>(
        recv_op.getLoc(), type, builder->getStringAttr(recv_op.key()),
        mlir::TF::ShapeAttr::get(builder->getContext(),
                                 type.dyn_cast<mlir::TensorType>()),
        mlir::dtensor::LayoutAttr::get(builder->getContext(),
                                       recv_op.layout()));

    recv_op.replaceAllUsesWith(new_recv_op.output());
    recv_op.erase();

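    // Look up the DTensorSend paired with this recv so that propagation can
    // continue with the operand being sent from the other mesh.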
    auto send_op = GetCorrespondingDTensorSendRecvOp<mlir::TF::DTensorRecv>(
        module, new_recv_op);

    // Note: the original recv op has already been erased, so emit any error
    // on the replacement recv op.
    if (!send_op.ok())
      return new_recv_op.emitOpError(send_op.status().error_message());

    // Recursively run shape inference on the input of the send op.
    return BackwardShapeInferenceToRestoreOp(
        module, builder, send_op.value()->getOperand(0), type);
  }
  return mlir::success();
}

// For every AssignVariableOp, if the value X being assigned to the resource
// tensor has an unknown shape, then X may come from the result of a
// tf.RestoreV2 op.
//
// We can infer the unknown shape of a tf.RestoreV2 result through the
// resource tensors of the AssignVariableOps that consume its results.
//
// Thus, we propagate the underlying resource tensor's shape and dtype
// backward to the tf.RestoreV2 op.
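//
// Sketch of the flow (the concrete type tensor<4x8xf32> is hypothetical): an
// AssignVariableOp whose resource operand carries subtype tensor<4x8xf32> but
// whose assigned value has unknown rank yields known_type = tensor<4x8xf32>,
// which is then pushed backward to the producing tf.RestoreV2 op.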
mlir::LogicalResult PropagateShapeInformationFromAssignVariableOp(
    mlir::ModuleOp module) {
  module.walk([&](mlir::TF::AssignVariableOp assign_op) {
    // Only act when `value` has an unknown shape.
    if (ValueRank(assign_op.value()) == -1) {
      StatusOr<llvm::ArrayRef<int64_t>> shape =
          GetShapeOfValue(assign_op.resource());
      if (!shape.ok()) {
        assign_op->emitOpError(
            "Resource tensor was expected to have shape information but was "
            "missing it during CheckpointShapeInference.");
        return mlir::WalkResult::interrupt();
      }
      // Propagate the shape backward to all the ops that use or produce
      // the value with the missing shape.
      mlir::OpBuilder builder(assign_op);
      mlir::Type known_type = GetSubtypeOrSelf(assign_op.resource());
      if (mlir::failed(BackwardShapeInferenceToRestoreOp(
              module, &builder, assign_op.value(), known_type))) {
        assign_op->emitOpError(
            "Error doing backward shape inference from AssignVariableOp "
            "during CheckpointShapeInference.");
        return mlir::WalkResult::interrupt();
      }
    }
    return mlir::WalkResult::advance();
  });

  return mlir::success();
}

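// Pass that runs the AssignVariableOp-driven backward shape propagation over
// the whole module.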
struct DTensorInferShapesForRestoreV2Op
    : public DTensorInferShapesForRestoreV2OpBase<
          DTensorInferShapesForRestoreV2Op> {
  void runOnOperation() override {
    auto module = getOperation();
    if (failed(PropagateShapeInformationFromAssignVariableOp(module)))
      return signalPassFailure();
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorInferShapesForRestoreV2Op() {
  return std::make_unique<DTensorInferShapesForRestoreV2Op>();
}

}  // namespace dtensor
}  // namespace tensorflow