/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/spmd_expander.h"

#include <climits>
#include <cstdint>
#include <iterator>
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/proto/layout.pb.h"

namespace tensorflow {
namespace dtensor {

// static
SPMDExpanderRegistry* SPMDExpanderRegistry::Global() {
  static SPMDExpanderRegistry* registry = new SPMDExpanderRegistry();
  return registry;
}

SPMDExpanderBase* SPMDExpanderRegistry::GetPropagateFnForOp(
    mlir::Operation* op) {
  auto key = OpName(op);
  auto fn = op_to_propagate_fn_map_.find(key);
  if (fn == op_to_propagate_fn_map_.end()) return nullptr;
  return fn->second.get();
}

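// A hedged usage sketch (illustration only, not code in this file): expanders
// are typically registered at startup by keeping the InitOnStartupMarker
// returned below in a static, keyed by the same op name string that OpName()
// produces for the op. `MyOpSPMDExpander` and the "tf.MyOp" key are
// hypothetical and shown only to demonstrate the call shape.
//
//   static InitOnStartupMarker const my_op_registration =
//       SPMDExpanderRegistry::Global()->RegisterPropagateFn(
//           "tf.MyOp", std::make_unique<MyOpSPMDExpander>());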
InitOnStartupMarker SPMDExpanderRegistry::RegisterPropagateFn(
    std::string opName, std::unique_ptr<SPMDExpanderBase> prop) {
  CHECK(op_to_propagate_fn_map_  // Crash ok
            .insert_or_assign(opName, std::move(prop))
            .second);
  return {};
}

Status SPMDExpanderBase::ExpandOpAndSetLayout(mlir::Operation* op,
                                              mlir::Operation** output) {
  TF_ASSIGN_OR_RETURN(std::vector<absl::optional<Layout>> computed_layout,
                      ExtractLayoutFromOp(op));

  if (computed_layout.empty() && op->getNumResults() != 0) {
    return errors::InvalidArgument(
        absl::StrCat("No attached layout found for op: ", OpName(op),
                     ". This might be due to an error in layout propagation.")
            .c_str());
  }

  // `op` may be removed/replaced from the graph during SPMD expansion, so
  // extract the global output shape before expansion.
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> global_output_shapes;
  global_output_shapes.reserve(op->getNumResults());
  for (auto output_value : op->getResults()) {
    auto maybe_ranked =
        output_value.getType().dyn_cast<mlir::RankedTensorType>();
    // Do not extract the global shape if the shape isn't statically known.
    //
    // This is a bit subtle and relies on the static-shape check of the output
    // value below when extracting local_shape. We should probably consider a
    // placeholder for unknown shapes to avoid surprises in the future.
    //
    // Given the nature of the RestoreV2 op and its output ranks, we only
    // special-case RestoreV2 for now.
    if (llvm::isa<mlir::TF::RestoreV2Op, mlir::TF::DTensorRestoreV2Op>(op) &&
        (!maybe_ranked || !maybe_ranked.hasStaticShape()))
      continue;
    TF_ASSIGN_OR_RETURN(auto global_shape,
                        ExtractGlobalOutputShape(output_value));
    global_output_shapes.emplace_back(llvm::SmallVector<int64_t, 4>{
        global_shape.begin(), global_shape.end()});
  }

  TF_ASSIGN_OR_RETURN(*output, this->ExpandOp(op));

  // TODO(hthu): Use ToString() instead.
  SetLayoutOnOp(*output, absl::Span<absl::optional<Layout>>(
                             computed_layout.data(), computed_layout.size()));

  // Verify that the local shape of the expanded operation matches the shape
  // expected from the layout. Note that this does **not** catch all errors:
  // when a tensor dimension is sharded on a wrong mesh that has the same
  // device cardinality as the correct/expected mesh, this check still passes.
  for (const auto& output_layout_and_index :
       llvm::enumerate(llvm::zip((*output)->getResults(), computed_layout))) {
    const int index = output_layout_and_index.index();
    const auto& output_and_layout = output_layout_and_index.value();

    auto output_value = std::get<0>(output_and_layout);
    // Extract the static shape of `output_value` if possible, otherwise ignore
    // this output.
    auto local_expanded_shape_or_status = GetShapeOfValue(output_value);
    if (!local_expanded_shape_or_status.ok()) continue;

    const auto local_expanded_shape =
        local_expanded_shape_or_status.ValueOrDie();
    const auto& layout = std::get<1>(output_and_layout);
    const auto expected_global_shape =
        layout->GlobalShapeFromLocalShape(local_expanded_shape);

    for (const auto& expanded_and_true_global_shape :
         llvm::zip(global_output_shapes[index], expected_global_shape)) {
      const auto expanded_shape = std::get<0>(expanded_and_true_global_shape);
      const auto expected_shape = std::get<1>(expanded_and_true_global_shape);
      // If either shape has an unknown dimension, do not check/validate the
      // shape.
      if (expanded_shape <= 0 || expected_shape <= 0) continue;

      if (expanded_shape != expected_shape) {
        return errors::Internal(
            "SPMD expansion resulted in op output inconsistent with the "
            "provided layout.");
      }
    }
  }

  return OkStatus();
}

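// The ComputeLayoutForward/ComputeLayoutBackward methods below are default
// implementations; a concrete expander is expected to override at least one
// overload of each. A minimal sketch, assuming a hypothetical op whose single
// output simply reuses the layout of input 0 (`MyOpSPMDExpander` is
// illustrative only):
//
//   StatusOr<llvm::DenseMap<int, Layout>>
//   MyOpSPMDExpander::ComputeLayoutForward(
//       mlir::Operation* op,
//       const llvm::DenseMap<int, Layout>& input_layouts) {
//     llvm::DenseMap<int, Layout> output_layouts;
//     auto it = input_layouts.find(0);
//     if (it != input_layouts.end()) output_layouts[0] = it->second;
//     return output_layouts;
//   }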
StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  return errors::Unimplemented(
      "ComputeLayoutForward API must be implemented by the subclass.");
}

StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
    const llvm::DenseMap<int, Layout>& output_layouts) {
  return ComputeLayoutForward(op, input_layouts);
}

StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  return errors::Unimplemented(
      "ComputeLayoutBackward API must be implemented by the subclass.");
}

StatusOr<llvm::DenseMap<int, Layout>> SPMDExpanderBase::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts,
    const llvm::DenseMap<int, Layout>& output_layouts) {
  return ComputeLayoutBackward(op, output_layouts);
}

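// A hedged sketch of how a caller (e.g. an SPMD expansion pass) might drive
// the function below; the surrounding pass wiring is an assumption, not
// something this file prescribes:
//
//   mlir::Operation* expanded = nullptr;
//   TF_RETURN_IF_ERROR(RunSPMDExpansion(op, &expanded));
//   // When no expander is registered for `op`, `expanded` is `op` itself.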
Status RunSPMDExpansion(mlir::Operation* op, mlir::Operation** output) {
  SPMDExpanderBase* expander =
      SPMDExpanderRegistry::Global()->GetPropagateFnForOp(op);
  if (expander != nullptr) {
    return expander->ExpandOpAndSetLayout(op, output);
  } else {
    VLOG(1) << "No expansion found for " << OpName(op) << "\n";
    *output = op;
  }
  return OkStatus();
}

}  // namespace dtensor
}  // namespace tensorflow