/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

bool Equal(const ShardingSpec& a, const ShardingSpec& b) {
  return a.sharding_spec() == b.sharding_spec();
}

}  // namespace

// Einsum, like reductions, is implemented as a local operation followed by
// an all-reduce over dimensions that have been reduced.
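// As an illustrative example (not exhaustive): for "ab,bc->ac" with the
// contracting label 'b' sharded over mesh dimension 'x', each device
// multiplies its local slices to produce a partial result, and the partial
// results are then summed with an all-reduce over 'x'.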
StatusOr<mlir::Operation*> EinsumSPMDExpander::ExpandOp(mlir::Operation* op) {
  std::vector<Layout> input_layouts(op->getNumOperands());
  for (int i = 0; i < op->getNumOperands(); ++i) {
    TF_ASSIGN_OR_RETURN(auto layout,
                        ExtractLayoutFromOperand(op->getOperand(i)));
    if (!layout) return errors::InvalidArgument("missing layout for input ", i);
    input_layouts[i] = layout.value();
  }
  TF_ASSIGN_OR_RETURN(auto output_layout, ExtractSingleLayoutFromOp(op));
  if (!output_layout)
    return errors::InvalidArgument("missing output layout.");

  std::vector<mlir::Value> new_inputs;
  Layout layout_after_einsum;
  absl::flat_hash_set<std::string> reduce_dims;
  TF_RETURN_IF_ERROR(MaybeRelayoutInputs(input_layouts, op,
                                         output_layout.value(), reduce_dims,
                                         layout_after_einsum, new_inputs));

  mlir::OpBuilder builder(op);
  mlir::BlockAndValueMapping mapping;
  for (int i = 0; i < op->getNumOperands(); ++i)
    mapping.map(op->getOperand(i), new_inputs[i]);
  mlir::Operation* new_op = builder.clone(*op, mapping);
  // Note that the output shape of new_op is cloned from op, so we need to
  // update it to the local shape.
  new_op = InferSPMDExpandedLocalShape(new_op);

  if (!reduce_dims.empty()) {
    TF_ASSIGN_OR_RETURN(
        new_op, EmitAllReduce(builder, layout_after_einsum, reduce_dims, new_op,
                              kReduceOpAdd));
  }

  TF_ASSIGN_OR_RETURN(auto final_output,
                      EmitRelayout(new_op->getOpResult(0), layout_after_einsum,
                                   output_layout.value()));

  op->getOpResult(0).replaceAllUsesWith(final_output);
  op->erase();

  return final_output.getDefiningOp();
}

// TODO(power) -- we use a simplified equation parser here. Consider
// refactoring einsum_spmd_expander and reusing the TF parser.
//
// Given the input equation, this has 3 outputs:
//   reduced_dims: the set of equation labels that are reduced away (these
//     determine the mesh dimensions we need to all-reduce over).
//   input_mappings: for each equation input, the map from the equation labels
//     to the tensor dimensions of that label.
//   output_mapping: as above, but for the equation output.
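// For example (illustrative), for the equation "ab,bc->ac":
//   reduced_dims   = {'b'}
//   input_mappings = [{'a': [0], 'b': [1]}, {'b': [0], 'c': [1]}]
//   output_mapping = {'a': [0], 'c': [1]}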
Status ExtractEquationRelations(
    absl::string_view equation, absl::flat_hash_set<char>& reduced_dims,
    std::vector<absl::flat_hash_map<char, std::vector<int>>>& input_mappings,
    absl::flat_hash_map<char, std::vector<int>>& output_mapping) {
  std::pair<std::string, std::string> parts = absl::StrSplit(equation, "->");
  absl::flat_hash_set<char> non_reduced_dims;

  // Mark kept dimensions from the output.
  for (const auto& char_and_index : llvm::enumerate(parts.second)) {
    // TODO(b/172691887): Support broadcasting for einsum.
    if (char_and_index.value() == '.')
      return errors::Unimplemented(
          "Broadcasting is unimplemented for einsum. Received equation ",
          equation);
    non_reduced_dims.insert(char_and_index.value());

    // Construct the output mapping. Note that the output is not allowed to
    // have duplicate labels. This would mean that the datatype of
    // output_mapping should really be absl::flat_hash_map<char, int>, but
    // having the same type as input_mappings keeps GetSpecsFromLabelsAndMap
    // simpler.
    if (output_mapping.contains(char_and_index.value()))
      return errors::InvalidArgument("received label ", char_and_index.value(),
                                     " multiple times in the "
                                     "output of einsum equation ",
                                     equation);

    output_mapping[char_and_index.value()].emplace_back(
        char_and_index.index());
  }

  std::vector<std::string> inputs = absl::StrSplit(parts.first, ',');
  // Note that the TF Einsum op only supports at most 2 inputs. This is
  // slightly confusing as the tf.einsum interface actually supports > 2
  // inputs.
  if (inputs.size() > 2)
    return errors::InvalidArgument(
        "einsum only supports at most 2 inputs; received equation ", equation,
        " which has ", inputs.size(), " inputs");

  input_mappings.resize(inputs.size());

  // Compute the input mappings and keep track of labels which are reduced.
  for (int i = 0; i < inputs.size(); ++i) {
    for (const auto& char_and_index : llvm::enumerate(inputs[i])) {
      input_mappings[i][char_and_index.value()].emplace_back(
          char_and_index.index());
      if (!non_reduced_dims.contains(char_and_index.value()))
        reduced_dims.insert(char_and_index.value());
    }
  }

  return OkStatus();
}

// For a set of layouts and mappings from labels to offsets in the layouts,
// return a mapping of labels to ShardingSpecs.
// If a label appears multiple times with different mesh dimensions in the
// sharding specs, we raise an error when replicate_incompatible_dimensions is
// false; otherwise we treat the dimension as if it were unsharded.
// Labels whose dimensions are all unsharded are not recorded in the output.
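// For example (illustrative), for "ab,bc->ac" with input layouts
// ['x', unsharded] and [unsharded, 'y'], the result maps 'a' -> 'x' and
// 'c' -> 'y'; the unsharded label 'b' is omitted.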
StatusOr<absl::flat_hash_map<char, ShardingSpec>> GetLabelToShardingSpec(
    bool replicate_incompatible_dimensions, const std::vector<Layout>& layouts,
    const std::vector<absl::flat_hash_map<char, std::vector<int>>>& mappings) {
  absl::flat_hash_map<char, ShardingSpec> label_to_sharding_spec;
  absl::flat_hash_set<char> incompatible_labels;

  // For each mapping, identify the mesh dimension and whether it has been
  // reduced away.
  for (int index = 0; index < layouts.size(); ++index) {
    for (const auto& mapping : mappings[index]) {
      for (int offset : mapping.second) {
        if (offset >= layouts[index].rank())
          return errors::InvalidArgument(
              llvm::formatv(
                  "specified einsum equation for operand {0} tried to "
                  "read layout at offset {1}, but layout is {2} with rank "
                  "{3}",
                  index, offset, layouts[index].ToString(),
                  layouts[index].rank())
                  .str());

        const ShardingSpec& sharding_spec = layouts[index].dim(offset);

        if (label_to_sharding_spec.contains(mapping.first)) {
          if (Layout::IsShardedSpec(sharding_spec) &&
              !Equal(label_to_sharding_spec[mapping.first], sharding_spec)) {
            if (!replicate_incompatible_dimensions)
              return errors::InvalidArgument(
                  llvm::formatv(
                      "incompatible mesh dimensions in equation, label '{0}' "
                      "is mapped to mesh dimensions '{1}' and '{2}'",
                      mapping.first, sharding_spec.sharding_spec(),
                      label_to_sharding_spec[mapping.first].sharding_spec())
                      .str());
            else
              incompatible_labels.insert(mapping.first);
          }
        } else if (Layout::IsShardedSpec(sharding_spec)) {
          label_to_sharding_spec[mapping.first] = sharding_spec;
        }
      }
    }
  }

  // For labels that had incompatible dimensions, treat them as replicated.
  // We would need to insert some all-to-all in the SPMD expansion for these.
  for (char label : incompatible_labels) label_to_sharding_spec.erase(label);

  return label_to_sharding_spec;
}

// The layout we generated may be invalid, as the same mesh dimension may be
// used multiple times, e.g. ab,bc->ac (i.e. matmul) with a and c sharded over
// the same mesh dimension. In this case we mark all such dimensions as
// replicated.
StatusOr<Layout> VerifyOrFixLayout(
    std::pair<std::vector<ShardingSpec>, absl::flat_hash_map<std::string, int>>
        pair,
    const Mesh& mesh) {
  std::vector<ShardingSpec> sharding_specs = pair.first;
  absl::flat_hash_map<std::string, int> dimension_use_count = pair.second;
  for (int i = 0; i < sharding_specs.size(); ++i)
    if (Layout::IsShardedSpec(sharding_specs[i]) &&
        dimension_use_count[sharding_specs[i].sharding_spec()] > 1)
      sharding_specs[i].set_sharding_spec(Layout::kUnshardedDim);
  return Layout::GetLayout(sharding_specs, mesh);
}

// Constructs the sharding specs for a layout on a given mesh from the label
// to tensor dimension map and the label to mesh dimension (ShardingSpec) map,
// together with a count of how many times each mesh dimension is used.
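// For example (illustrative), for the output mapping of "ab,bc->ac" with both
// 'a' -> 'x' and 'c' -> 'x', this returns the specs ['x', 'x'] and the use
// count {'x': 2}; VerifyOrFixLayout then resets both specs to kUnshardedDim.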
std::pair<std::vector<ShardingSpec>, absl::flat_hash_map<std::string, int>>
GetSpecsFromLabelsAndMap(
    const absl::flat_hash_map<char, std::vector<int>>& label_to_index,
    const absl::flat_hash_map<char, ShardingSpec>& label_to_sharding_spec) {
  int layout_rank = 0;
  for (const auto& label_and_indices : label_to_index)
    layout_rank += label_and_indices.second.size();

  std::vector<ShardingSpec> sharding_specs(layout_rank);
  absl::flat_hash_map<std::string, int> dimension_use_count;
  absl::flat_hash_set<std::string> dimension_use_set;
  for (const auto& label_and_indices : label_to_index) {
    const auto& loc = label_to_sharding_spec.find(label_and_indices.first);
    if (loc != label_to_sharding_spec.end()) {
      const ShardingSpec& sharding_spec = loc->second;
      for (int index : label_and_indices.second)
        sharding_specs[index] = sharding_spec;
      dimension_use_count[sharding_spec.sharding_spec()] +=
          label_and_indices.second.size();
    } else {
      for (int index : label_and_indices.second)
        sharding_specs[index].set_sharding_spec(Layout::kUnshardedDim);
    }
  }
  return std::make_pair(sharding_specs, dimension_use_count);
}

StatusOr<llvm::DenseMap<int, Layout>> EinsumSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  if (input_layouts.empty()) return llvm::DenseMap<int, Layout>();

  TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));

  // Need the mapping of input and output labels from the equation.
  auto einsum_op = mlir::cast<mlir::TF::EinsumOp>(op);
  size_t num_inputs = einsum_op.getNumOperands();
  std::string equation = einsum_op.equation().str();
  absl::flat_hash_set<char> reduced_dim_labels;
  std::vector<absl::flat_hash_map<char, std::vector<int>>> input_mappings;
  absl::flat_hash_map<char, std::vector<int>> output_mapping;

  TF_RETURN_IF_ERROR(ExtractEquationRelations(equation, reduced_dim_labels,
                                              input_mappings, output_mapping));
  if (input_mappings.size() != num_inputs)
    return errors::InvalidArgument(
        "Einsum equation ", equation, " has ", input_mappings.size(),
        " inputs but this op has ", num_inputs, " inputs.");

  // GetLabelToShardingSpec requires two layouts if the einsum equation has
  // two inputs. We may only have one layout, so make the other replicated.
  // This has the same effect as using only the defined layout and treating
  // all the missing dimensions as replicated.
  std::vector<Layout> layouts;
  for (int k = 0; k < num_inputs; ++k) {
    if (input_layouts.find(k) != input_layouts.end()) {
      layouts.emplace_back(input_layouts.lookup(k));
    } else {
      int rank = ValueRank(op->getOperand(k));
      if (rank < 0) return errors::InvalidArgument("No rank for input ", k);
      // This case can only happen when there are two inputs (input 1 - k is
      // the other input). In this branch input k's layout is missing, so
      // input 1 - k's layout must be present.
      layouts.emplace_back(Layout::ReplicatedOnMesh(mesh, rank));
    }
  }

  // For each input, identify the mesh dimension each label is sharded over.
  TF_ASSIGN_OR_RETURN(
      auto input_label_to_sharding_spec,
      GetLabelToShardingSpec(
          /*replicate_incompatible_dimensions=*/true, layouts, input_mappings));
  // Compute the output layout based on the retained mesh dimensions.
  TF_ASSIGN_OR_RETURN(
      const auto& output_layout,
      VerifyOrFixLayout(GetSpecsFromLabelsAndMap(output_mapping,
                                                 input_label_to_sharding_spec),
                        mesh));
  return llvm::DenseMap<int, Layout>({{0, output_layout}});
}

StatusOr<llvm::DenseMap<int, Layout>> EinsumSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  if (output_layouts.find(0) == output_layouts.end())
    return llvm::DenseMap<int, Layout>();

  const Layout output_layout = output_layouts.lookup(0);
  TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));

  // Need the mapping of input and output labels from the equation.
  auto einsum_op = mlir::cast<mlir::TF::EinsumOp>(op);
  size_t num_inputs = einsum_op.getNumOperands();
  std::string equation = einsum_op.equation().str();
  absl::flat_hash_set<char> reduced_dim_labels;
  std::vector<absl::flat_hash_map<char, std::vector<int>>> input_mappings;
  absl::flat_hash_map<char, std::vector<int>> output_mapping;

  TF_RETURN_IF_ERROR(ExtractEquationRelations(equation, reduced_dim_labels,
                                              input_mappings, output_mapping));
  if (input_mappings.size() != num_inputs)
    return errors::InvalidArgument(
        "Einsum equation ", equation, " has ", input_mappings.size(),
        " inputs but this op has ", num_inputs, " inputs.");

  // Using the output mapping, construct an equation label to mesh dimension
  // mapping.
  TF_ASSIGN_OR_RETURN(auto output_label_to_sharding_spec,
                      GetLabelToShardingSpec(
                          /*replicate_incompatible_dimensions=*/false,
                          {output_layout}, {output_mapping}));

  // Defines the set of labels that may be set to Any. The conditions for an
  // operand label to be set to Any are that it 1) is not present in the
  // output and 2) is not repeated in any operand.
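  // For example (illustrative), in "ab,bc->ac" the contracted label 'b' is
  // not in the output and not repeated, so its operand dimensions may be left
  // as kAny; in "abb,bc->ac" the label 'b' is repeated in the first operand,
  // so it is removed from the set.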
  absl::flat_hash_set<char> labels_for_any;
  for (const auto& operand_mapping : input_mappings)
    for (const auto& label_to_indices : operand_mapping)
      labels_for_any.insert(label_to_indices.first);

  // Filter repeated labels.
  for (const auto& operand_mapping : input_mappings)
    for (const auto& label_to_indices : operand_mapping)
      if (label_to_indices.second.size() > 1)
        labels_for_any.erase(label_to_indices.first);

  // Filter labels in the output.
  for (const auto& label_to_indices : output_mapping)
    labels_for_any.erase(label_to_indices.first);

  llvm::DenseMap<int, Layout> input_layouts(num_inputs);

  // Derive operand sharding specs from the output's sharding specs.
  for (size_t i = 0; i < num_inputs; ++i) {
    absl::flat_hash_map<char, std::vector<int>> labels_to_indices =
        input_mappings[i];
    std::pair<std::vector<ShardingSpec>, absl::flat_hash_map<std::string, int>>
        sharding_specs_and_dim_count = GetSpecsFromLabelsAndMap(
            labels_to_indices, output_label_to_sharding_spec);

    std::vector<ShardingSpec> sharding_specs =
        sharding_specs_and_dim_count.first;
    absl::flat_hash_map<std::string, int> dim_count =
        sharding_specs_and_dim_count.second;

    // Flip "unsharded" specs to "any" if they are present in the set.
    for (const auto& label_to_indices : labels_to_indices) {
      char label = label_to_indices.first;
      if (labels_for_any.contains(label)) {
        int index = label_to_indices.second[0];
        sharding_specs[index].set_sharding_spec(Layout::kAny);
      }
    }
    TF_ASSIGN_OR_RETURN(
        const auto& layout,
        VerifyOrFixLayout(std::make_pair(sharding_specs, dim_count), mesh));
    input_layouts[i] = layout;
  }

  return input_layouts;
}

// A few things we don't support, but could:
// * "xx->" or "xx->x": trace-like operations where at least one input
//   dimension for x is sharded. If both are sharded, we can compute the
//   einsum on the diagonal machines in the mesh, with 0s on the off
//   diagonals, and then all-reduce the much smaller matrix.
Status EinsumSPMDExpander::MaybeRelayoutInputs(
    const std::vector<Layout>& input_layouts, mlir::Operation* op,
    const Layout& output_layout, absl::flat_hash_set<std::string>& reduce_dims,
    Layout& einsum_layout, std::vector<mlir::Value>& new_inputs) {
  if (!mlir::isa<mlir::TF::EinsumOp>(op))
    return errors::InvalidArgument(
        "called einsum SPMD expander but op is not Einsum.");

  mlir::TF::EinsumOp einsum = mlir::cast<mlir::TF::EinsumOp>(op);
  std::vector<absl::flat_hash_map<char, std::vector<int>>> input_mappings;
  absl::flat_hash_map<char, std::vector<int>> output_mapping;
  absl::flat_hash_set<char> contracting_labels;
  absl::flat_hash_set<char> all_labels;
  TF_RETURN_IF_ERROR(ExtractEquationRelations(einsum.equation().str(),
                                              contracting_labels,
                                              input_mappings, output_mapping));

  for (const auto& input_mapping : input_mappings)
    for (const auto& char_and_positions : input_mapping)
      all_labels.emplace(char_and_positions.first);

  // We will update this map throughout this function with the following
  // rules:
  // 1. The sharding of a label which is not in the map is unknown.
  // 2. Once the sharding of a label becomes known and is unsharded, we
  //    won't change it.
  TF_ASSIGN_OR_RETURN(auto input_label_to_sharding_spec,
                      GetLabelToShardingSpec(
                          /*replicate_incompatible_dimensions=*/false,
                          input_layouts, input_mappings));

  TF_ASSIGN_OR_RETURN(const auto output_label_to_sharding_spec,
                      GetLabelToShardingSpec(
                          /*replicate_incompatible_dimensions=*/false,
                          {output_layout}, {output_mapping}));

  for (const char label : all_labels) {
    if (input_label_to_sharding_spec.contains(label) &&
        output_label_to_sharding_spec.contains(label) &&
        !Equal(input_label_to_sharding_spec[label],
               output_label_to_sharding_spec.find(label)->second))
      return errors::InvalidArgument(
          "for label ", label, " input and output layouts are sharded on ",
          "non-equal dimensions ",
          input_label_to_sharding_spec[label].sharding_spec(), " and ",
          output_label_to_sharding_spec.find(label)->second.sharding_spec(),
          " respectively");
  }

  // First priority is to ensure that labels which occur at least twice on one
  // side never get sharded, as we cannot deal with that. This corresponds to
  // taking a trace on that input, which requires the dimension to be
  // unsharded.
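  // For example (illustrative), for the trace-like equation "aa->" the label
  // 'a' occurs twice in the input, so any sharding it carries is reset to
  // kUnshardedDim below (and the input is relayouted accordingly).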
  for (const auto& input_mapping : input_mappings)
    for (const auto& char_and_positions : input_mapping)
      if (char_and_positions.second.size() > 1)
        input_label_to_sharding_spec[char_and_positions.first]
            .set_sharding_spec(Layout::kUnshardedDim);

  absl::flat_hash_map<std::string, absl::flat_hash_set<char>>
      sharding_dim_to_non_contracting_labels;
  absl::flat_hash_map<std::string, absl::flat_hash_set<char>>
      sharding_dim_to_contracting_labels;
  for (const auto& label_and_spec : input_label_to_sharding_spec) {
    if (Layout::IsShardedSpec(label_and_spec.second)) {
      if (contracting_labels.contains(label_and_spec.first))
        sharding_dim_to_contracting_labels[label_and_spec.second
                                               .sharding_spec()]
            .insert(label_and_spec.first);
      else
        sharding_dim_to_non_contracting_labels[label_and_spec.second
                                                   .sharding_spec()]
            .insert(label_and_spec.first);
    }
  }

  // If a non-contracting dimension is sharded in the output, unsharded in the
  // input, and no other label is sharded on that mesh dimension, then shard
  // it.
  // This handles the *,x . x,* -> *,y case and also the case where batch
  // dimensions are sharded on the output but not the input.
  for (const char label : all_labels) {
    // Note that only sharded labels are in output_label_to_sharding_spec, so
    // there is no need to check that the spec is sharded.
    if (!contracting_labels.contains(label) &&
        output_label_to_sharding_spec.contains(label) &&
        !input_label_to_sharding_spec.contains(label)) {
      const ShardingSpec& sharding_spec =
          output_label_to_sharding_spec.find(label)->second;
      const std::string& string_spec = sharding_spec.sharding_spec();
      if (!sharding_dim_to_non_contracting_labels.contains(string_spec) &&
          !sharding_dim_to_contracting_labels.contains(string_spec)) {
        input_label_to_sharding_spec[label] = sharding_spec;
        sharding_dim_to_non_contracting_labels[string_spec].insert(label);
      }
    }
  }

  // Handle the case where two non-contracting dimensions have the same
  // sharding spec.
  // Note that the case of three non-contracting dimensions having the same
  // sharding spec is impossible: since there are at most two inputs, at least
  // one input would have two dimensions with the same sharding spec.
  // This handles the y,x . x,y -> *,y case.
  absl::flat_hash_set<std::string> dims_with_multiple_labels;
  for (const auto& spec_and_labels : sharding_dim_to_non_contracting_labels) {
    if (spec_and_labels.second.size() > 1) {
      assert(spec_and_labels.second.size() == 2);
      dims_with_multiple_labels.insert(spec_and_labels.first);
    }
  }
  for (const auto& dim : dims_with_multiple_labels) {
    // TODO(bfontain): Update this to pick the default label to keep based on
    // shape.
    char label_to_keep = 0xFF;
    // Note that the conditions evaluated in the loop below are mutually
    // exclusive, as no two labels in the output have the same sharding spec.
    // If no label is found in the output we choose the lexicographically
    // least label, to keep this stable with respect to ordering.
    for (const char label : sharding_dim_to_non_contracting_labels[dim]) {
      if (output_label_to_sharding_spec.contains(label) &&
          output_label_to_sharding_spec.find(label)->second.sharding_spec() ==
              dim) {
        label_to_keep = label;
        break;
      } else if (label < label_to_keep) {
        label_to_keep = label;
      }
    }
    for (const char label : sharding_dim_to_non_contracting_labels[dim])
      if (label != label_to_keep)
        input_label_to_sharding_spec[label].set_sharding_spec(
            Layout::kUnshardedDim);
    sharding_dim_to_non_contracting_labels[dim].clear();
    sharding_dim_to_non_contracting_labels[dim].insert(label_to_keep);
  }

  // Handle the case where a non-contracting and a contracting dimension have
  // the same sharding spec. For now we always unshard the contracting axis.
  // Note that this is safe.
  // This handles the x,y . *,y -> x,y case.
  for (const auto& spec_and_labels : sharding_dim_to_contracting_labels) {
    if (!spec_and_labels.second.empty() &&
        !sharding_dim_to_non_contracting_labels[spec_and_labels.first]
             .empty()) {
      assert(spec_and_labels.second.size() == 1);
      assert(sharding_dim_to_non_contracting_labels[spec_and_labels.first]
                 .size() == 1);
      input_label_to_sharding_spec[*spec_and_labels.second.begin()]
          .set_sharding_spec(Layout::kUnshardedDim);
    }
  }

  // Relayout the inputs.
  mlir::OpBuilder builder(op);
  new_inputs.resize(input_mappings.size());
  for (int i = 0; i < input_mappings.size(); ++i) {
    TF_ASSIGN_OR_RETURN(
        const Layout new_input_layout,
        VerifyOrFixLayout(GetSpecsFromLabelsAndMap(
                              input_mappings[i], input_label_to_sharding_spec),
                          output_layout.mesh()));

    TF_ASSIGN_OR_RETURN(
        new_inputs[i],
        EmitRelayout(op->getOperand(i), input_layouts[i], new_input_layout));
  }

  TF_ASSIGN_OR_RETURN(
      einsum_layout,
      VerifyOrFixLayout(GetSpecsFromLabelsAndMap(output_mapping,
                                                 input_label_to_sharding_spec),
                        output_layout.mesh()));

  for (const auto& contracting : contracting_labels)
    reduce_dims.emplace(
        input_label_to_sharding_spec[contracting].sharding_spec());

  return OkStatus();
}

}  // namespace dtensor
}  // namespace tensorflow