/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/expansions/dataparallel_spmd_expander.h"

#include <algorithm>
#include <string>

#include "absl/container/flat_hash_set.h"
#include "absl/types/optional.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"
#include "tensorflow/dtensor/proto/layout.pb.h"

namespace tensorflow {
namespace dtensor {

namespace {

// Checks that all layouts are fully replicated.
bool AllReplicated(const std::vector<Layout>& layouts) {
  for (const auto& layout : layouts) {
    if (!layout.IsFullyReplicated()) return false;
  }
  return true;
}

// Checks that all tensors are batch parallel.
bool AllBatchParallel(const std::vector<Layout>& layouts,
                      const llvm::DenseMap<int, int>& batchable_indices) {
  for (int i = 0; i < layouts.size(); ++i) {
    if (!layouts[i].IsBatchParallel(batchable_indices.lookup(i))) return false;
  }
  return true;
}

// Checks that all layouts have the same batch rank.
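// For example, a rank-4 layout whose batchable_indices entry is 2 (two
// trailing non-batch dims) contributes a batch rank of 4 - 2 = 2.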
bool SameBatchRank(const std::vector<Layout>& layouts,
                   const llvm::DenseMap<int, int>& batchable_indices) {
  absl::flat_hash_set<int> batch_ranks;
  // Add the batch rank of every batchable layout.
  for (auto const& idx_and_non_batch_rank : batchable_indices) {
    auto const& idx = idx_and_non_batch_rank.first;
    auto const& non_batch_rank = idx_and_non_batch_rank.second;
    batch_ranks.insert(layouts[idx].rank() - non_batch_rank);
  }
  return batch_ranks.size() <= 1;
}

// Checks whether a layout is present for any of the given indices.
bool AnyLayoutExist(const llvm::DenseMap<int, Layout>& layouts,
                    const llvm::DenseMap<int, int>& indices) {
  for (auto const& idx_and_unused : indices) {
    auto const& idx = idx_and_unused.first;
    if (layouts.find(idx) != layouts.end()) return true;
  }
  return false;
}

// Given the layouts to merge and a map from the index of each batchable
// layout to the rank of its non-batch dimensions, merges the batchable
// layouts into a single layout. Assumes all batchable layouts have the same
// batch rank, which is enforced before this point.
//
// A batch dimension of the merged layout is sharded only when every sharded
// input layout uses the same sharding spec for that dimension; otherwise the
// dimension is left unsharded.
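// For example, merging [x, unsharded] with [x, y] gives [x, y]; merging
// [x, y] with [x, z] gives [x, unsharded]; and a merged result of [x, x] is
// deduplicated to [unsharded, unsharded].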
StatusOr<Layout> MergeBatchLayouts(
    const llvm::DenseMap<int, Layout>& layouts,
    const llvm::DenseMap<int, int>& batchable_args, const Mesh& mesh) {
  // Find an index that has a layout; any one works because all batchable
  // layouts share the same batch rank.
  int layout_idx = -1;
  for (auto const& idx_and_unused : batchable_args) {
    auto const& idx = idx_and_unused.first;
    if (layouts.find(idx) != layouts.end()) layout_idx = idx;
  }

  int batch_rank =
      layouts.lookup(layout_idx).rank() - batchable_args.lookup(layout_idx);
  // Initialize with replicated.
  std::vector<std::string> merged_specs(batch_rank, Layout::kUnshardedDim);

  // Merge layouts. If the layouts disagree on the sharding of a dimension,
  // leave that dimension replicated.
  for (int i = 0; i < batch_rank; ++i) {
    absl::flat_hash_set<std::string> spec_set;
    for (auto const& arg_idx_and_unused : batchable_args) {
      auto const& arg_idx = arg_idx_and_unused.first;
      if (layouts.find(arg_idx) == layouts.end()) continue;
      const std::string spec = layouts.lookup(arg_idx).sharding_spec(i);
      if (spec != Layout::kUnshardedDim) {
        spec_set.insert(spec);
      }
    }
    if (spec_set.size() == 1) {
      merged_specs[i] = *spec_set.begin();
    } else {
      merged_specs[i] = Layout::kUnshardedDim;
    }
  }

  // Deduplicate repeated uses of a mesh dim: [x, x] -> [unsharded, unsharded].
  absl::flat_hash_map<std::string, int> counter;
  for (const std::string& spec : merged_specs) counter[spec] += 1;
  for (std::string& spec : merged_specs) {
    if (counter[spec] > 1) {
      spec = Layout::kUnshardedDim;
    }
  }
  return Layout::GetLayout(merged_specs, mesh);
}

// Chooses an intermediate layout to relayout to. Picks the most frequently
// sharded mesh dimension for every batch dimension, then deduplicates (n-1)
// of every repeated mesh dimension, leaving only the rightmost duplicate
// sharded.
//
// Note that this assumes the number of batch dims is the same for every
// batchable tensor, which is enforced before this point.
//
// Examples:
// Given layouts [x,y], [x,z], [y,z], produces [x,z].
// Deduplication: [x,x] becomes [*, x].
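//
// Ties in the frequency count are broken in favor of the mesh dimension that
// reaches the winning count first (operands are scanned before outputs).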
StatusOr<Layout> IntermediateBatchLayout(
    const std::vector<Layout>& operand_layouts,
    const llvm::DenseMap<int, int>& batchable_operands,
    const std::vector<Layout>& output_layouts,
    const llvm::DenseMap<int, int>& batchable_outputs, const Mesh& mesh) {
  if (batchable_operands.empty()) {
    return errors::Unimplemented(
        llvm::formatv("There must be at least one batchable operand").str());
  }
  int first_batcharg_index = batchable_operands.begin()->first;
  int batch_rank = operand_layouts[first_batcharg_index].rank() -
                   batchable_operands.find(first_batcharg_index)->second;

  std::vector<std::string> batch_specs(batch_rank, Layout::kUnshardedDim);

  // For each batch dimension, find the most commonly used mesh dimension
  // and set that as batch_specs[i].
  for (int i = 0; i < batch_rank; ++i) {
    std::string mesh_dim = Layout::kUnshardedDim;
    int max_count = 0;
    absl::flat_hash_map<std::string, int> counter;
    // Add operand counts.
    for (auto const& idx_and_unused : batchable_operands) {
      auto const& idx = idx_and_unused.first;
      std::string spec = operand_layouts[idx].sharding_spec(i);
      if (spec != Layout::kUnshardedDim) counter[spec]++;
      if (counter[spec] > max_count) {
        max_count = counter[spec];
        mesh_dim = spec;
      }
    }
    // Add output counts.
    for (auto const& idx_and_unused : batchable_outputs) {
      auto const& idx = idx_and_unused.first;
      std::string spec = output_layouts[idx].sharding_spec(i);
      if (spec != Layout::kUnshardedDim) counter[spec]++;
      if (counter[spec] > max_count) {
        max_count = counter[spec];
        mesh_dim = spec;
      }
    }
    batch_specs[i] = mesh_dim;
  }
  // Deduplicate: unshard all but the rightmost occurrence of each repeated
  // mesh dim.
  absl::flat_hash_map<std::string, int> counter;
  for (const std::string& spec : batch_specs) counter[spec] += 1;
  for (std::string& spec : batch_specs) {
    if (counter[spec] > 1) {
      counter[spec]--;
      spec = Layout::kUnshardedDim;
    }
  }
  return Layout::GetLayout(batch_specs, mesh);
}
}  // namespace

// Relayouts all operands that have batch dimensions to be batch sharded.
// The outputs then get the correct inferred local shape from the operands.
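//
// The expansion proceeds in three steps:
//   1. Relayout every batchable operand to a common intermediate layout whose
//      batch dimensions are sharded and remaining dimensions are replicated.
//   2. Infer the op's local (per-device) shape from the new operands.
//   3. Relayout every batchable output from the intermediate layout to its
//      required layout, tying all results together with an IdentityN op.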
StatusOr<mlir::Operation*> DataparallelSPMDExpander::RelayoutOperandsAndOutputs(
    mlir::Operation* op, const std::vector<Layout>& operand_layouts,
    const std::vector<Layout>& output_layouts) {
  TF_ASSIGN_OR_RETURN(const auto mesh, ExtractDeviceMeshEnclosingCluster(op));
  TF_ASSIGN_OR_RETURN(
      const Layout intermediate_batch_layout,
      IntermediateBatchLayout(operand_layouts, batchable_operands_,
                              output_layouts, batchable_outputs_, mesh));
  // Relayout batchable operands.
  for (auto i = 0; i < operand_layouts.size(); ++i) {
    // Relayout operands that have a batch dimension to the intermediate
    // layout.
    if (batchable_operands_.find(i) != batchable_operands_.end()) {
      int replicated_rank =
          ValueRank(op->getOperand(i)) - intermediate_batch_layout.rank();
      TF_ASSIGN_OR_RETURN(
          auto new_layout,
          ConcatenateLayouts(intermediate_batch_layout,
                             Layout::ReplicatedOnMesh(mesh, replicated_rank)));
      TF_ASSIGN_OR_RETURN(
          const auto new_operand,
          EmitRelayout(op->getOperand(i), operand_layouts[i], new_layout));
      op->setOperand(i, new_operand);
    }
  }
  // Expand to the local (per-device) shape.
  op = InferSPMDExpandedLocalShape(op);

  llvm::SmallPtrSet<mlir::Operation*, 4> newly_created_ops;
  llvm::SmallVector<mlir::Value, 4> generated_outputs;
  llvm::SmallVector<mlir::Type, 4> generated_types;

  // Track the op that comes last after splitting.
  mlir::Operation* last_op_after_splitting = op;

  // Relayout batchable outputs.
  for (auto i = 0; i < output_layouts.size(); ++i) {
    // Relayout outputs that have a batch dimension from the intermediate
    // layout back to their required layout.
    if (batchable_outputs_.find(i) != batchable_outputs_.end()) {
      int replicated_rank =
          ValueRank(op->getResult(i)) - intermediate_batch_layout.rank();
      TF_ASSIGN_OR_RETURN(
          auto new_layout,
          ConcatenateLayouts(intermediate_batch_layout,
                             Layout::ReplicatedOnMesh(mesh, replicated_rank)));
      TF_ASSIGN_OR_RETURN(auto new_output,
                          EmitRelayout(op->getOpResult(i), new_layout,
                                       output_layouts[i], &newly_created_ops));
      generated_outputs.emplace_back(new_output);
      generated_types.emplace_back(new_output.getType());
      if (last_op_after_splitting->isBeforeInBlock(
              new_output.getDefiningOp())) {
        last_op_after_splitting = new_output.getDefiningOp();
      }
    } else {
      generated_outputs.push_back(op->getResult(i));
      generated_types.push_back(op->getResult(i).getType());
    }
  }
  mlir::OpBuilder builder(op);
  builder.setInsertionPointAfter(last_op_after_splitting);

  // Tie all outputs together with identity_n.
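  // This lets us replace every use of the original results outside the newly
  // created ops with the relayouted values, while leaving the relayout ops
  // (which still consume the original results) untouched.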
  auto identity_op = builder.create<mlir::TF::IdentityNOp>(
      op->getLoc(), generated_types, generated_outputs);
  newly_created_ops.insert(identity_op);
  for (int i = 0; i < output_layouts.size(); ++i) {
    op->getOpResult(i).replaceAllUsesExcept(identity_op.getResult(i),
                                            newly_created_ops);
  }

  return identity_op.getOperation();
}

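// Expands the op by relayouting batchable operands and outputs through a
// common batch-sharded intermediate layout. If every operand and output is
// already fully replicated, only the local shape needs to be inferred.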
StatusOr<mlir::Operation*> DataparallelSPMDExpander::ExpandOp(
    mlir::Operation* op) {
  TF_ASSIGN_OR_RETURN(const auto output_layouts,
                      ExtractRequiredLayoutFromOp(op));
  TF_ASSIGN_OR_RETURN(const auto operand_layouts,
                      ExtractRequiredLayoutFromOperands(op));
  // Check that all inputs and outputs are batch parallel.
  if (!AllBatchParallel(operand_layouts, batchable_operands_) ||
      !AllBatchParallel(output_layouts, batchable_outputs_)) {
    return errors::Unimplemented(
        llvm::formatv("All operands and outputs must be batch parallel.")
            .str());
  }
  // Check that the batch dimension rank is the same for all batchable
  // tensors.
  if (!SameBatchRank(operand_layouts, batchable_operands_) ||
      !SameBatchRank(output_layouts, batchable_outputs_)) {
    return errors::Unimplemented(
        llvm::formatv("All operands and outputs with batch dimensions must "
                      "have same batch dimension rank")
            .str());
  }
  if (AllReplicated(output_layouts) && AllReplicated(operand_layouts))
    return InferSPMDExpandedLocalShape(op);
  return RelayoutOperandsAndOutputs(op, operand_layouts, output_layouts);
}

// Takes the layouts of all batchable operands and merges them to produce a
// single layout for all batchable outputs.
StatusOr<llvm::DenseMap<int, Layout>>
DataparallelSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  TF_ASSIGN_OR_RETURN(const auto mesh, ExtractDeviceMeshEnclosingCluster(op));

  llvm::DenseMap<int, Layout> output_layouts;
  // Compute output layouts.
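  // If none of the batchable operands has a known layout, return an empty
  // map so that no output layout is suggested here.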
  if (AnyLayoutExist(input_layouts, batchable_operands_)) {
    TF_ASSIGN_OR_RETURN(
        const Layout& batch_output_layout,
        MergeBatchLayouts(input_layouts, batchable_operands_, mesh));
    for (const auto& output_and_index : llvm::enumerate(op->getOpResults())) {
      const int output_index = output_and_index.index();
      auto output = output_and_index.value();
      int rank = ValueRank(output);
      if (batchable_outputs_.find(output_index) != batchable_outputs_.end()) {
        int replicated_rank = batchable_outputs_[output_index];
        TF_ASSIGN_OR_RETURN(auto new_layout,
                            ConcatenateLayouts(batch_output_layout,
                                               Layout::ReplicatedOnMesh(
                                                   mesh, replicated_rank)));
        output_layouts[output_index] = new_layout;
      } else {
        output_layouts[output_index] = Layout::ReplicatedOnMesh(mesh, rank);
      }
    }
  }
  return output_layouts;
}

// Takes the layouts of all batchable outputs and merges them to produce a
// single layout for all batchable operands.
StatusOr<llvm::DenseMap<int, Layout>>
DataparallelSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  TF_ASSIGN_OR_RETURN(const auto mesh, ExtractDeviceMeshEnclosingCluster(op));

  llvm::DenseMap<int, Layout> input_layouts;
  // Compute input layouts as follows: for operand indices that have a batch
  // dimension, batch shard them the same way as the merged output layout.
  // Otherwise, replicate.
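  // As in ComputeLayoutForward, an empty map is returned when no batchable
  // output has a known layout.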
  if (AnyLayoutExist(output_layouts, batchable_outputs_)) {
    TF_ASSIGN_OR_RETURN(
        const Layout& batch_operand_layout,
        MergeBatchLayouts(output_layouts, batchable_outputs_, mesh));

    for (const auto& operand_and_index : llvm::enumerate(op->getOperands())) {
      const int operand_index = operand_and_index.index();
      auto operand = operand_and_index.value();
      int rank = ValueRank(operand);
      if (batchable_operands_.find(operand_index) !=
          batchable_operands_.end()) {
        int replicated_rank = batchable_operands_[operand_index];
        TF_ASSIGN_OR_RETURN(auto new_layout,
                            ConcatenateLayouts(batch_operand_layout,
                                               Layout::ReplicatedOnMesh(
                                                   mesh, replicated_rank)));
        input_layouts[operand_index] = new_layout;
      } else {
        input_layouts[operand_index] = Layout::ReplicatedOnMesh(mesh, rank);
      }
    }
  }
  return input_layouts;
}
}  // namespace dtensor
}  // namespace tensorflow