/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/expansions/dataparallel_spmd_expander.h"

#include <algorithm>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"
#include "tensorflow/dtensor/proto/layout.pb.h"

namespace tensorflow {
namespace dtensor {

namespace {

// Checks that all layouts are fully replicated.
bool AllReplicated(const std::vector<Layout>& layouts) {
  for (const auto& layout : layouts) {
    if (!layout.IsFullyReplicated()) return false;
  }
  return true;
}

// Check all tensors are batch parallel.
bool AllBatchParallel(const std::vector<Layout>& layouts,
                      const llvm::DenseMap<int, int>& batchable_indices) {
  for (int i = 0; i < layouts.size(); ++i) {
    if (!layouts[i].IsBatchParallel(batchable_indices.lookup(i))) return false;
  }
  return true;
}

// Check all Layouts have the same batch rank.
bool SameBatchRank(const std::vector<Layout>& layouts,
                   const llvm::DenseMap<int, int>& batchable_indices) {
  absl::flat_hash_set<int> batch_ranks;
  // Add all batch ranks of layouts.
  for (auto const& idx_and_non_batch_rank : batchable_indices) {
    auto const& idx = idx_and_non_batch_rank.first;
    auto const& non_batch_rank = idx_and_non_batch_rank.second;
    batch_ranks.insert(layouts[idx].rank() - non_batch_rank);
  }
  return batch_ranks.size() <= 1;
}

// Checks whether a layout exists for any of the given indices.
bool AnyLayoutExist(const llvm::DenseMap<int, Layout>& layouts,
                    const llvm::DenseMap<int, int>& indices) {
  for (auto const& idx_and_unused : indices) {
    auto const& idx = idx_and_unused.first;
    if (layouts.find(idx) != layouts.end()) return true;
  }
  return false;
}

// Given layouts to merge and a map from the index of each batchable layout to
// the rank of its non-batch dimensions, merges those batchable layouts into a
// single layout. Assumes all batch ranks are the same for all batchable
// layouts, which is enforced before this is called.
//
// The layouts are merged so that the result is sharded in a tensor dimension
// if and only if all present layouts are sharded on the same sharding spec for
// that dimension.
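//
// For example, merging {[x, unsharded], [x, y]} gives [x, y], while merging
// {[x, y], [y, x]} gives [unsharded, unsharded] because the layouts disagree
// on both dimensions. Merging {[x, unsharded], [unsharded, x]} first yields
// [x, x], which the deduplication step below turns into
// [unsharded, unsharded].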
StatusOr<Layout> MergeBatchLayouts(
    const llvm::DenseMap<int, Layout>& layouts,
    const llvm::DenseMap<int, int>& batchable_args, const Mesh& mesh) {
  // Find an index that has a layout so we can compute the batch rank.
  int layout_idx = -1;
  for (auto const& idx_and_unused : batchable_args) {
    auto const& idx = idx_and_unused.first;
    if (layouts.find(idx) != layouts.end()) layout_idx = idx;
  }

  int batch_rank =
      layouts.lookup(layout_idx).rank() - batchable_args.lookup(layout_idx);
  // Initialize with replicated.
  std::vector<std::string> merged_specs(batch_rank, Layout::kUnshardedDim);

  // Merge layouts. If the present layouts don't agree on the sharding spec for
  // a batch dimension, leave that dimension replicated.
  for (int i = 0; i < batch_rank; ++i) {
    absl::flat_hash_set<std::string> spec_set;
    for (auto const& arg_idx_and_unused : batchable_args) {
      auto const& arg_idx = arg_idx_and_unused.first;
      if (layouts.find(arg_idx) == layouts.end()) continue;
      const std::string spec = layouts.lookup(arg_idx).sharding_spec(i);
      if (spec != Layout::kUnshardedDim) {
        spec_set.insert(spec);
      }
    }
    if (spec_set.size() == 1) {
      merged_specs[i] = *spec_set.begin();
    } else {
      merged_specs[i] = Layout::kUnshardedDim;
    }
  }

  // Deduplicate repeated uses of a mesh dim: [x, x] -> [unsharded, unsharded].
  absl::flat_hash_map<std::string, int> counter;
  for (const std::string& spec : merged_specs) counter[spec] += 1;
  for (std::string& spec : merged_specs) {
    if (counter[spec] > 1) {
      spec = Layout::kUnshardedDim;
    }
  }
  return Layout::GetLayout(merged_specs, mesh);
}

// Chooses an intermediate layout to relayout to. Picks the most frequently
// sharded mesh dimension for every batch dimension, then deduplicates (n-1)
// of all repeated mesh dimensions, leaving the rightmost duplicate sharded.
//
// Note that this assumes the number of batch dims is the same for every
// batchable tensor, which is enforced before this is called.
//
// Examples:
// Given layouts [x,y], [x,z], [y,z], produces [x,z].
// Deduplication: [x,x] will become [*, x].
StatusOr<Layout> IntermediateBatchLayout(
    const std::vector<Layout>& operand_layouts,
    const llvm::DenseMap<int, int>& batchable_operands,
    const std::vector<Layout>& output_layouts,
    const llvm::DenseMap<int, int>& batchable_outputs, const Mesh& mesh) {
  if (batchable_operands.empty()) {
    return errors::Unimplemented(
        llvm::formatv("There must be at least one batchable operand").str());
  }
  int first_batcharg_index = batchable_operands.begin()->first;
  int batch_rank = operand_layouts[first_batcharg_index].rank() -
                   batchable_operands.find(first_batcharg_index)->second;

  std::vector<std::string> batch_specs(batch_rank, Layout::kUnshardedDim);

  // For each batch dimension, finds the most commonly used mesh dimension
  // and sets that to batch_specs[i].
  for (int i = 0; i < batch_rank; ++i) {
    std::string mesh_dim = Layout::kUnshardedDim;
    int max_count = 0;
    absl::flat_hash_map<std::string, int> counter;
    // Add operand counts.
    for (auto const& idx_and_unused : batchable_operands) {
      auto const& idx = idx_and_unused.first;
      std::string spec = operand_layouts[idx].sharding_spec(i);
      if (spec != Layout::kUnshardedDim) counter[spec]++;
      if (counter[spec] > max_count) {
        max_count = counter[spec];
        mesh_dim = spec;
      }
    }
    // Add output counts.
    for (auto const& idx_and_unused : batchable_outputs) {
      auto const& idx = idx_and_unused.first;
      std::string spec = output_layouts[idx].sharding_spec(i);
      if (spec != Layout::kUnshardedDim) counter[spec]++;
      if (counter[spec] > max_count) {
        max_count = counter[spec];
        mesh_dim = spec;
      }
    }
    batch_specs[i] = mesh_dim;
  }
  // Deduplicate.
  absl::flat_hash_map<std::string, int> counter;
  for (const std::string& spec : batch_specs) counter[spec] += 1;
  for (std::string& spec : batch_specs) {
    if (counter[spec] > 1) {
      counter[spec]--;
      spec = Layout::kUnshardedDim;
    }
  }
  return Layout::GetLayout(batch_specs, mesh);
}
}  // namespace

// Relayouts all operands that have batch dimensions to a batch-sharded layout.
// The outputs will get the correct inferred local shape from the operands.
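//
// The overall flow: relayout every batchable operand to a common intermediate
// batch layout, expand the op to its local shape, relayout every batchable
// output from the intermediate layout to its requested layout, and finally
// tie all outputs together with an IdentityN op so that existing uses can be
// redirected to the relayouted results.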
StatusOr<mlir::Operation*> DataparallelSPMDExpander::RelayoutOperandsAndOutputs(
    mlir::Operation* op, const std::vector<Layout>& operand_layouts,
    const std::vector<Layout>& output_layouts) {
  TF_ASSIGN_OR_RETURN(const auto mesh, ExtractDeviceMeshEnclosingCluster(op));
  TF_ASSIGN_OR_RETURN(
      const Layout intermediate_batch_layout,
      IntermediateBatchLayout(operand_layouts, batchable_operands_,
                              output_layouts, batchable_outputs_, mesh));
  // Relayout batchable operands.
  for (auto i = 0; i < operand_layouts.size(); ++i) {
    // Relayout operands that have a batch dimension to the intermediate
    // layout.
    if (batchable_operands_.find(i) != batchable_operands_.end()) {
      int replicated_rank =
          ValueRank(op->getOperand(i)) - intermediate_batch_layout.rank();
      TF_ASSIGN_OR_RETURN(
          auto new_layout,
          ConcatenateLayouts(intermediate_batch_layout,
                             Layout::ReplicatedOnMesh(mesh, replicated_rank)));
      TF_ASSIGN_OR_RETURN(
          const auto new_operand,
          EmitRelayout(op->getOperand(i), operand_layouts[i], new_layout));
      op->setOperand(i, new_operand);
    }
  }
  // Expand to local shape.
  op = InferSPMDExpandedLocalShape(op);

  llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;
  llvm::SmallVector<mlir::Value, 4> generated_outputs;
  llvm::SmallVector<mlir::Type, 4> generated_types;

  // Track the op that comes last after splitting.
  mlir::Operation* last_op_after_splitting = op;

  // Relayout batchable outputs.
  for (auto i = 0; i < output_layouts.size(); ++i) {
    // Relayout to a batch-sharded layout if the tensor has a batch dim.
    if (batchable_outputs_.find(i) != batchable_outputs_.end()) {
      int replicated_rank =
          ValueRank(op->getResult(i)) - intermediate_batch_layout.rank();
      TF_ASSIGN_OR_RETURN(
          auto new_layout,
          ConcatenateLayouts(intermediate_batch_layout,
                             Layout::ReplicatedOnMesh(mesh, replicated_rank)));
      TF_ASSIGN_OR_RETURN(auto new_output,
                          EmitRelayout(op->getOpResult(i), new_layout,
                                       output_layouts[i], &newly_created_ops));
      generated_outputs.emplace_back(new_output);
      generated_types.emplace_back(new_output.getType());
      if (last_op_after_splitting->isBeforeInBlock(
              new_output.getDefiningOp())) {
        last_op_after_splitting = new_output.getDefiningOp();
      }
    } else {
      generated_outputs.push_back(op->getResult(i));
      generated_types.push_back(op->getResult(i).getType());
    }
  }
  mlir::OpBuilder builder(op);
  builder.setInsertionPointAfter(last_op_after_splitting);

  // Tie all outputs together with an IdentityN op.
  auto identity_op = builder.create<mlir::TF::IdentityNOp>(
      op->getLoc(), generated_types, generated_outputs);
  newly_created_ops.insert(identity_op);
  for (int i = 0; i < output_layouts.size(); ++i) {
    op->getOpResult(i).replaceAllUsesExcept(identity_op.getResult(i),
                                            newly_created_ops);
  }

  return identity_op.getOperation();
}

StatusOr<mlir::Operation*> DataparallelSPMDExpander::ExpandOp(
    mlir::Operation* op) {
  TF_ASSIGN_OR_RETURN(const auto output_layouts,
                      ExtractRequiredLayoutFromOp(op));
  TF_ASSIGN_OR_RETURN(const auto operand_layouts,
                      ExtractRequiredLayoutFromOperands(op));
  // Check that all inputs and outputs are batch parallel.
  if (!AllBatchParallel(operand_layouts, batchable_operands_) ||
      !AllBatchParallel(output_layouts, batchable_outputs_)) {
    return errors::Unimplemented(
        llvm::formatv("All operands and outputs must be batch parallel.")
            .str());
  }
  // Check that the batch dimension rank is the same for all batchable tensors.
  if (!SameBatchRank(operand_layouts, batchable_operands_) ||
      !SameBatchRank(output_layouts, batchable_outputs_)) {
    return errors::Unimplemented(
        llvm::formatv("All operands and outputs with batch dimensions must "
                      "have same batch dimension rank")
            .str());
  }
  if (AllReplicated(output_layouts) && AllReplicated(operand_layouts))
    return InferSPMDExpandedLocalShape(op);
  return RelayoutOperandsAndOutputs(op, operand_layouts, output_layouts);
}

// Takes all layouts of the batchable operands and merges them to produce a
// single layout for all batchable outputs.
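//
// For example, if the batchable operand layouts merge to the batch layout
// [x, y], a batchable output with two non-batch dimensions is assigned the
// layout [x, y, unsharded, unsharded], while any non-batchable output is
// fully replicated.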
StatusOr<llvm::DenseMap<int, Layout>>
DataparallelSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  TF_ASSIGN_OR_RETURN(const auto mesh, ExtractDeviceMeshEnclosingCluster(op));

  llvm::DenseMap<int, Layout> output_layouts;
  // Compute output layouts.
  if (AnyLayoutExist(input_layouts, batchable_operands_)) {
    TF_ASSIGN_OR_RETURN(
        const Layout& batch_output_layout,
        MergeBatchLayouts(input_layouts, batchable_operands_, mesh));
    for (const auto& output_and_index : llvm::enumerate(op->getOpResults())) {
      const int output_index = output_and_index.index();
      auto output = output_and_index.value();
      int rank = ValueRank(output);
      if (batchable_outputs_.find(output_index) != batchable_outputs_.end()) {
        int replicated_rank = batchable_outputs_[output_index];
        TF_ASSIGN_OR_RETURN(auto new_layout,
                            ConcatenateLayouts(batch_output_layout,
                                               Layout::ReplicatedOnMesh(
                                                   mesh, replicated_rank)));
        output_layouts[output_index] = new_layout;
      } else {
        output_layouts[output_index] = Layout::ReplicatedOnMesh(mesh, rank);
      }
    }
  }
  return output_layouts;
}

// Takes all layouts of the batchable outputs and merges them to produce a
// single layout for all batchable operands.
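//
// Mirrors ComputeLayoutForward: each batchable operand gets the merged batch
// layout on its batch dimensions and is replicated on its remaining
// dimensions, while non-batchable operands are fully replicated.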
StatusOr<llvm::DenseMap<int, Layout>>
DataparallelSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  TF_ASSIGN_OR_RETURN(const auto mesh, ExtractDeviceMeshEnclosingCluster(op));

  llvm::DenseMap<int, Layout> input_layouts;
  // Compute input layouts in the following way: for operand indices that
  // have a batch dimension, batch shard them the same way as the output
  // layouts. Otherwise, replicate.
  if (AnyLayoutExist(output_layouts, batchable_outputs_)) {
    TF_ASSIGN_OR_RETURN(
        const Layout& batch_operand_layout,
        MergeBatchLayouts(output_layouts, batchable_outputs_, mesh));

    for (const auto& operand_and_index : llvm::enumerate(op->getOperands())) {
      const int operand_index = operand_and_index.index();
      auto operand = operand_and_index.value();
      int rank = ValueRank(operand);
      if (batchable_operands_.find(operand_index) !=
          batchable_operands_.end()) {
        int replicated_rank = batchable_operands_[operand_index];
        TF_ASSIGN_OR_RETURN(auto new_layout,
                            ConcatenateLayouts(batch_operand_layout,
                                               Layout::ReplicatedOnMesh(
                                                   mesh, replicated_rank)));
        input_layouts[operand_index] = new_layout;
      } else {
        input_layouts[operand_index] = Layout::ReplicatedOnMesh(mesh, rank);
      }
    }
  }
  return input_layouts;
}

}  // namespace dtensor
}  // namespace tensorflow