/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/spmd_expander.h"

#include <climits>
#include <cstdint>
#include <iterator>
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/proto/layout.pb.h"

namespace tensorflow {
namespace dtensor {

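// Returns the process-wide registry that maps op names to their SPMD
// expanders. The instance is created on first use and intentionally never
// freed, the usual pattern for function-local singletons in TensorFlow.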
// static
SPMDExpanderRegistry* SPMDExpanderRegistry::Global() {
  static SPMDExpanderRegistry* registry = new SPMDExpanderRegistry();
  return registry;
}

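// Looks up the expander registered under the operation name of `op` (for TF
// dialect ops this is the fully qualified MLIR name, e.g. "tf.Identity").
// Returns nullptr when no expander has been registered for that name.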
SPMDExpanderBase* SPMDExpanderRegistry::GetPropagateFnForOp(
    mlir::Operation* op) {
  auto key = OpName(op);
  auto fn = op_to_propagate_fn_map_.find(key);
  if (fn == op_to_propagate_fn_map_.end()) return nullptr;
  return fn->second.get();
}

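// Registers `prop` as the expander for `opName`; the CHECK fails on a
// duplicate registration. Registration normally happens at static
// initialization time, typically through the REGISTER_SPMD macro declared in
// spmd_expander.h, along these lines (illustrative only, not verbatim):
//
//   REGISTER_SPMD(Identity, TF::IdentityOp, IdentitySPMDExpander);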
InitOnStartupMarker SPMDExpanderRegistry::RegisterPropagateFn(
    std::string opName, std::unique_ptr<SPMDExpanderBase> prop) {
  CHECK(op_to_propagate_fn_map_  // Crash ok
            .insert_or_assign(opName, std::move(prop))
            .second);
  return {};
}

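// Expands `op` from its global form into its local (per-device) SPMD form:
// (1) read the output layouts that layout propagation attached to the op,
// (2) snapshot the global output shapes (the op may be replaced during
// expansion), (3) run the op-specific ExpandOp() and re-attach the layouts,
// and (4) cross-check the expanded local shapes against the shapes implied
// by the layouts.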
Status SPMDExpanderBase::ExpandOpAndSetLayout(mlir::Operation* op,
                                              mlir::Operation** output) {
  TF_ASSIGN_OR_RETURN(std::vector<absl::optional<Layout>> computed_layout,
                      ExtractLayoutFromOp(op));

  if (computed_layout.empty() && op->getNumResults() != 0) {
    return errors::InvalidArgument(
        absl::StrCat("No attached layout found for op: ", OpName(op),
                     ". This might be due to an error in layout propagation.")
            .c_str());
  }

  // `op` may be removed/replaced from the graph during SPMD expansion, so
  // extract the global output shape before expansion.
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> global_output_shapes;
  global_output_shapes.reserve(op->getNumResults());
  for (auto output_value : op->getResults()) {
    auto maybe_ranked =
        output_value.getType().dyn_cast<mlir::RankedTensorType>();
    // Do not extract global shape if the shape isn't statically known.
    //
    // This is a bit subtle and relies on the check of static shape of output
    // value below when extracting local_shape. We probably should consider a
    // placeholder for unknown shapes to avoid surprises in the future.
    //
    // Given the nature of RestoreV2 op and its output ranks, we only special
    // case for RestoreV2 for now.
    if (llvm::isa<mlir::TF::RestoreV2Op, mlir::TF::DTensorRestoreV2Op>(op) &&
        (!maybe_ranked || !maybe_ranked.hasStaticShape()))
      continue;
    TF_ASSIGN_OR_RETURN(auto global_shape,
                        ExtractGlobalOutputShape(output_value));
    global_output_shapes.emplace_back(llvm::SmallVector<int64_t, 4>{
        global_shape.begin(), global_shape.end()});
  }

  TF_ASSIGN_OR_RETURN(*output, this->ExpandOp(op));

  // TODO(hthu): Use ToString() instead.
  SetLayoutOnOp(*output, absl::Span<absl::optional<Layout>>(
                             computed_layout.data(), computed_layout.size()));

  // Verify that the local shape of the expanded operation matches the shape
  // expected from the layout. Note that this does **not** catch all errors:
  // when a tensor dimension is sharded on a wrong mesh that has the same
  // device cardinality as the correct/expected mesh, this check still passes.
  for (const auto& output_layout_and_index :
       llvm::enumerate(llvm::zip((*output)->getResults(), computed_layout))) {
    const int index = output_layout_and_index.index();
    const auto& output_and_layout = output_layout_and_index.value();

    auto output_value = std::get<0>(output_and_layout);
    // Extract the static shape of `output_value` if possible, otherwise ignore
    // this output.
    auto local_expanded_shape_or_status = GetShapeOfValue(output_value);
    if (!local_expanded_shape_or_status.ok()) continue;

    const auto local_expanded_shape =
        local_expanded_shape_or_status.ValueOrDie();
    const auto& layout = std::get<1>(output_and_layout);
    const auto expected_global_shape =
        layout->GlobalShapeFromLocalShape(local_expanded_shape);

    for (const auto& expanded_and_true_global_shape :
         llvm::zip(global_output_shapes[index], expected_global_shape)) {
      const auto expanded_shape = std::get<0>(expanded_and_true_global_shape);
      const auto expected_shape = std::get<1>(expanded_and_true_global_shape);
      // If either shape has an unknown dimension, do not check/validate the
      // shape.
      if (expanded_shape <= 0 || expected_shape <= 0) continue;

      if (expanded_shape != expected_shape) {
        return errors::Internal(
            "SPMD expansion resulted in op output inconsistent with the "
            "provided layout.");
      }
    }
  }

  return OkStatus();
}

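// Layout propagation hooks. ComputeLayoutForward infers output layouts from
// input layouts (the second overload additionally sees any already-known
// output layouts); subclasses are expected to override at least one
// overload. A hypothetical elementwise expander might simply forward its
// first input layout to its single output, along these lines:
//
//   StatusOr<llvm::DenseMap<int, Layout>> ComputeLayoutForward(
//       mlir::Operation* op,
//       const llvm::DenseMap<int, Layout>& input_layouts) override {
//     llvm::DenseMap<int, Layout> output_layouts;
//     auto it = input_layouts.find(0);
//     if (it != input_layouts.end()) output_layouts.insert({0, it->second});
//     return output_layouts;
//   }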
StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  return errors::Unimplemented(
      "ComputeLayoutForward API must be implemented by the subclass.");
}

StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
    const llvm::DenseMap<int, Layout>& output_layouts) {
  return ComputeLayoutForward(op, input_layouts);
}

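// ComputeLayoutBackward is the mirror image: it infers the layouts an op
// requires for its inputs from the desired output layouts, which lets layout
// propagation iterate in both directions.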
StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  return errors::Unimplemented(
      "ComputeLayoutBackward API must be implemented by the subclass.");
}

StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
    const llvm::DenseMap<int, Layout>& output_layouts) {
  return ComputeLayoutBackward(op, output_layouts);
}

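// Entry point used by the SPMD expansion pass: dispatches `op` to its
// registered expander, or leaves the op untouched (*output == op) when no
// expander is registered for it.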
Status RunSPMDExpansion(mlir::Operation* op, mlir::Operation** output) {
  SPMDExpanderBase* expander =
      SPMDExpanderRegistry::Global()->GetPropagateFnForOp(op);
  if (expander != nullptr) {
    return expander->ExpandOpAndSetLayout(op, output);
  } else {
    VLOG(1) << "No expansion found for " << OpName(op) << "\n";
    *output = op;
  }
  return OkStatus();
}

}  // namespace dtensor
}  // namespace tensorflow