/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/expansions/einsum_spmd_expander.h"

#include <string>

#include "absl/strings/string_view.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/collectives.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/shape_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

bool Equal(const ShardingSpec& a, const ShardingSpec& b) {
  return a.sharding_spec() == b.sharding_spec();
}

}  // namespace

// Einsum, like reductions, is implemented as a local operation followed by
// an all-reduce over dimensions that have been reduced.
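// For example, for "ab,bc->ac" (a matmul) with the contracting label b
// sharded over mesh dimension x, each device multiplies its local slices and
// the partial products are then summed with an all-reduce over x.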
StatusOr<mlir::Operation*> EinsumSPMDExpander::ExpandOp(mlir::Operation* op) {
  std::vector<Layout> input_layouts(op->getNumOperands());
  for (int i = 0; i < op->getNumOperands(); ++i) {
    TF_ASSIGN_OR_RETURN(auto layout,
                        ExtractLayoutFromOperand(op->getOperand(i)));
    if (!layout) return errors::InvalidArgument("missing layout for input ", i);
    input_layouts[i] = layout.value();
  }
  TF_ASSIGN_OR_RETURN(auto output_layout, ExtractSingleLayoutFromOp(op));
  if (!output_layout)
    return errors::InvalidArgument("is missing output layout.");

  std::vector<mlir::Value> new_inputs;
  Layout layout_after_einsum;
  absl::flat_hash_set<std::string> reduce_dims;
  TF_RETURN_IF_ERROR(MaybeRelayoutInputs(input_layouts, op,
                                         output_layout.value(), reduce_dims,
                                         layout_after_einsum, new_inputs));

  mlir::OpBuilder builder(op);
  mlir::BlockAndValueMapping mapping;
  for (int i = 0; i < op->getNumOperands(); ++i)
    mapping.map(op->getOperand(i), new_inputs[i]);
  mlir::Operation* new_op = builder.clone(*op, mapping);
  // Note that the output shape of new_op is cloned from op, so we need to
  // update it to the local shape.
  new_op = InferSPMDExpandedLocalShape(new_op);

  if (!reduce_dims.empty()) {
    TF_ASSIGN_OR_RETURN(
        new_op, EmitAllReduce(builder, layout_after_einsum, reduce_dims, new_op,
                              kReduceOpAdd));
  }

  TF_ASSIGN_OR_RETURN(auto final_output,
                      EmitRelayout(new_op->getOpResult(0), layout_after_einsum,
                                   output_layout.value()));

  op->getOpResult(0).replaceAllUsesWith(final_output);
  op->erase();

  return final_output.getDefiningOp();
}

// TODO(power) -- we use a simplified equation parser here. consider
// refactoring einsum_spmd_expander and reusing the TF parser.
//
// Given the input equation, this has 3 outputs:
// reduced_dims: the set of equation labels that are reduced (contracted)
//   away, i.e. labels that do not appear in the output.
// input_mappings: for each equation input, the map from the equation labels
//   to the tensor dimensions of that label.
// output_mapping: as above, but for the equation output.
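// For example, for the matmul-like equation "ab,bc->ac":
//   reduced_dims   = {'b'}
//   input_mappings = [{'a': [0], 'b': [1]}, {'b': [0], 'c': [1]}]
//   output_mapping = {'a': [0], 'c': [1]}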
Status ExtractEquationRelations(
    absl::string_view equation, absl::flat_hash_set<char>& reduced_dims,
    std::vector<absl::flat_hash_map<char, std::vector<int>>>& input_mappings,
    absl::flat_hash_map<char, std::vector<int>>& output_mapping) {
  std::pair<std::string, std::string> parts = absl::StrSplit(equation, "->");
  absl::flat_hash_set<char> non_reduced_dims;

  // Mark kept dimensions from the output.
  for (const auto& char_and_index : llvm::enumerate(parts.second)) {
    // TODO(b/172691887): Support Broadcasting for einsum.
    if (char_and_index.value() == '.')
      return errors::Unimplemented(
          "Broadcasting is unimplemented for einsum. Received equation ",
          equation);
    non_reduced_dims.insert(char_and_index.value());

    // Construct the output mapping; note that the output is not allowed to
    // have duplicate labels. This would mean that the datatype of
    // output_mapping should really be absl::flat_hash_map<char, int>, but
    // having the same type as the input_mapping keeps
    // GetSpecsFromLabelsAndMap simpler.
    if (output_mapping.contains(char_and_index.value()))
      return errors::InvalidArgument("received label ", char_and_index.value(),
                                     " multiple times in the "
                                     "output of einsum equation ",
                                     equation);

    output_mapping[char_and_index.value()].emplace_back(char_and_index.index());
  }

  std::vector<std::string> inputs = absl::StrSplit(parts.first, ',');
  // Note that the TF einsum op only supports at most 2 inputs. This is
  // slightly confusing as the tf.einsum interface actually supports > 2
  // inputs.
  if (inputs.size() > 2)
    return errors::InvalidArgument(
        "einsum only supports at most 2 inputs, received equation ", equation,
        " which has ", inputs.size(), " inputs");

  input_mappings.resize(inputs.size());

  // Compute the input mappings and keep track of labels which are reduced.
  for (int i = 0; i < inputs.size(); ++i) {
    for (const auto& char_and_index : llvm::enumerate(inputs[i])) {
      input_mappings[i][char_and_index.value()].emplace_back(
          char_and_index.index());
      if (!non_reduced_dims.contains(char_and_index.value()))
        reduced_dims.insert(char_and_index.value());
    }
  }

  return OkStatus();
}

// For a set of layouts and mappings from labels to offsets in the layouts,
// return a mapping of labels to ShardingSpecs.
// If a label appears multiple times with different mesh dimensions in the
// sharding specs, we raise an error if replicate_incompatible_dimensions is
// false. Otherwise we treat the dimension as if it were unsharded.
// Labels with unsharded dimensions are not recorded in the output.
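// For example, with layouts [x, y] and [y, unsharded] for the inputs of
// "ab,bc->ac" and mappings {'a': [0], 'b': [1]} and {'b': [0], 'c': [1]},
// the result is {'a': x, 'b': y}; 'c' maps to an unsharded dimension and is
// not recorded.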
StatusOr<absl::flat_hash_map<char, ShardingSpec>> GetLabelToShardingSpec(
    bool replicate_incompatible_dimensions, const std::vector<Layout>& layouts,
    const std::vector<absl::flat_hash_map<char, std::vector<int>>>& mappings) {
  absl::flat_hash_map<char, ShardingSpec> label_to_sharding_spec;
  absl::flat_hash_set<char> incompatible_labels;

  // For each mapping, identify the mesh dimension and whether it has been
  // reduced away.
  for (int index = 0; index < layouts.size(); ++index) {
    for (const auto& mapping : mappings[index]) {
      for (int offset : mapping.second) {
        if (offset >= layouts[index].rank())
          return errors::InvalidArgument(
              llvm::formatv(
                  "specified einsum equation for operand {0} tried to "
                  "read layout at offset {1}, but layout is {2} with rank "
                  "{3}",
                  index, offset, layouts[index].ToString(),
                  layouts[index].rank())
                  .str());

        const ShardingSpec& sharding_spec = layouts[index].dim(offset);

        if (label_to_sharding_spec.contains(mapping.first)) {
          if (Layout::IsShardedSpec(sharding_spec) &&
              !Equal(label_to_sharding_spec[mapping.first], sharding_spec)) {
            if (!replicate_incompatible_dimensions)
              return errors::InvalidArgument(
                  llvm::formatv(
                      "incompatible mesh dimensions in equation, label '{0}' "
                      "is mapped to mesh dimension '{1}' and '{2}'",
                      mapping.first, sharding_spec.sharding_spec(),
                      label_to_sharding_spec[mapping.first].sharding_spec())
                      .str());
            else
              incompatible_labels.insert(mapping.first);
          }
        } else if (Layout::IsShardedSpec(sharding_spec)) {
          label_to_sharding_spec[mapping.first] = sharding_spec;
        }
      }
    }
  }

  // For labels that had incompatible dimensions, treat them as replicated.
  // We would need to insert some all-to-alls in the SPMD expansion for these.
  for (char label : incompatible_labels) label_to_sharding_spec.erase(label);

  return label_to_sharding_spec;
}

// The layout we generated may be invalid as the same mesh dimension may be
// used multiple times, e.g. ab,bc->ac (i.e. matmul) with a and c sharded over
// the same mesh dimension. In this case we mark all such tensor dimensions as
// replicated.
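// For example, if both a and c were mapped to mesh dimension x, the tentative
// output specs for "ac" would be [x, x]; since x is used more than once, both
// are reset to kUnshardedDim.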
StatusOr<Layout> VerifyOrFixLayout(
    std::pair<std::vector<ShardingSpec>, absl::flat_hash_map<std::string, int>>
        pair,
    const Mesh& mesh) {
  std::vector<ShardingSpec> sharding_specs = pair.first;
  absl::flat_hash_map<std::string, int> dimension_use_count = pair.second;
  for (int i = 0; i < sharding_specs.size(); ++i)
    if (Layout::IsShardedSpec(sharding_specs[i]) &&
        dimension_use_count[sharding_specs[i].sharding_spec()] > 1)
      sharding_specs[i].set_sharding_spec(Layout::kUnshardedDim);
  return Layout::GetLayout(sharding_specs, mesh);
}

// Construct the sharding specs for a layout from the label to tensor
// dimension map and the label to mesh dimension (ShardingSpec) map. Also
// returns, for each mesh dimension, the number of tensor dimensions sharded
// over it.
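// For example, for output_mapping {'a': [0], 'c': [1]} and sharding specs
// {'a': x, 'b': y}, this returns specs [x, unsharded] and use counts {x: 1}.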
std::pair<std::vector<ShardingSpec>, absl::flat_hash_map<std::string, int>>
GetSpecsFromLabelsAndMap(
    const absl::flat_hash_map<char, std::vector<int>>& label_to_index,
    const absl::flat_hash_map<char, ShardingSpec>& label_to_sharding_spec) {
  int layout_rank = 0;
  for (const auto& label_and_indices : label_to_index)
    layout_rank += label_and_indices.second.size();

  std::vector<ShardingSpec> sharding_specs(layout_rank);
  absl::flat_hash_map<std::string, int> dimension_use_count;
  absl::flat_hash_set<std::string> dimension_use_set;
  for (const auto& label_and_indices : label_to_index) {
    const auto& loc = label_to_sharding_spec.find(label_and_indices.first);
    if (loc != label_to_sharding_spec.end()) {
      const ShardingSpec& sharding_spec = loc->second;
      for (int index : label_and_indices.second)
        sharding_specs[index] = sharding_spec;
      dimension_use_count[sharding_spec.sharding_spec()] +=
          label_and_indices.second.size();
    } else {
      for (int index : label_and_indices.second)
        sharding_specs[index].set_sharding_spec(Layout::kUnshardedDim);
    }
  }
  return std::make_pair(sharding_specs, dimension_use_count);
}

StatusOr<llvm::DenseMap<int, Layout>> EinsumSPMDExpander::ComputeLayoutForward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& input_layouts) {
  if (input_layouts.empty()) return llvm::DenseMap<int, Layout>();

  TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));

  // Need the mapping of input and output labels from the equation.
  auto einsum_op = mlir::cast<mlir::TF::EinsumOp>(op);
  size_t num_inputs = einsum_op.getNumOperands();
  std::string equation = einsum_op.equation().str();
  absl::flat_hash_set<char> reduced_dim_labels;
  std::vector<absl::flat_hash_map<char, std::vector<int>>> input_mappings;
  absl::flat_hash_map<char, std::vector<int>> output_mapping;

  TF_RETURN_IF_ERROR(ExtractEquationRelations(equation, reduced_dim_labels,
                                              input_mappings, output_mapping));
  if (input_mappings.size() != num_inputs)
    return errors::InvalidArgument(
        "Einsum equation ", equation, " has ", input_mappings.size(),
        " inputs but this op has ", num_inputs, " inputs.");

  // GetLabelToShardingSpec requires two layouts if the einsum equation has
  // two inputs. We may only have one layout, so make the other replicated.
  // This has the same effect as only using the defined layout and treating
  // all the missing dimensions as replicated.
  std::vector<Layout> layouts;
  for (int k = 0; k < num_inputs; ++k) {
    if (input_layouts.find(k) != input_layouts.end()) {
      layouts.emplace_back(input_layouts.lookup(k));
    } else {
      int rank = ValueRank(op->getOperand(k));
      if (rank < 0) return errors::InvalidArgument("No rank for input ", k);
      // This case can only happen when there are two inputs: input 1 - k is
      // the other input. In this branch of the if, input k's layout is
      // missing, so input 1 - k's layout must be present.
      layouts.emplace_back(Layout::ReplicatedOnMesh(mesh, rank));
    }
  }

  // For each input, map each equation label to the mesh dimension it is
  // sharded over.
  TF_ASSIGN_OR_RETURN(
      auto input_label_to_sharding_spec,
      GetLabelToShardingSpec(
          /*replicate_incompatible_dimensions=*/true, layouts, input_mappings));
  // Compute the output layout based on the retained mesh dimensions.
  TF_ASSIGN_OR_RETURN(
      const auto& output_layout,
      VerifyOrFixLayout(GetSpecsFromLabelsAndMap(output_mapping,
                                                 input_label_to_sharding_spec),
                        mesh));
  return llvm::DenseMap<int, Layout>({{0, output_layout}});
}

StatusOr<llvm::DenseMap<int, Layout>> EinsumSPMDExpander::ComputeLayoutBackward(
    mlir::Operation* op, const llvm::DenseMap<int, Layout>& output_layouts) {
  if (output_layouts.find(0) == output_layouts.end())
    return llvm::DenseMap<int, Layout>();

  const Layout output_layout = output_layouts.lookup(0);
  TF_ASSIGN_OR_RETURN(auto mesh, ExtractDeviceMeshEnclosingCluster(op));

  // Need the mapping of input and output labels from the equation.
  auto einsum_op = mlir::cast<mlir::TF::EinsumOp>(op);
  size_t num_inputs = einsum_op.getNumOperands();
  std::string equation = einsum_op.equation().str();
  absl::flat_hash_set<char> reduced_dim_labels;
  std::vector<absl::flat_hash_map<char, std::vector<int>>> input_mappings;
  absl::flat_hash_map<char, std::vector<int>> output_mapping;

  TF_RETURN_IF_ERROR(ExtractEquationRelations(equation, reduced_dim_labels,
                                              input_mappings, output_mapping));
  if (input_mappings.size() != num_inputs)
    return errors::InvalidArgument(
        "Einsum equation ", equation, " has ", input_mappings.size(),
        " inputs but this op has ", num_inputs, " inputs.");

  // Using the output mapping, construct an equation label to mesh dimension
  // mapping.
  TF_ASSIGN_OR_RETURN(auto output_label_to_sharding_spec,
                      GetLabelToShardingSpec(
                          /*replicate_incompatible_dimensions=*/false,
                          {output_layout}, {output_mapping}));

  // Defines a set of labels that could be set to Any. The conditions for an
  // operand label to be set to Any are that it 1) is not present in the
  // output and 2) is not repeated in any operand.
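  // For example, for "ab,bc->ac" with output layout [x, y], labels_for_any is
  // {'b'} and the suggested input layouts are [x, any] and [any, y].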
  absl::flat_hash_set<char> labels_for_any;
  for (const auto& operand_mapping : input_mappings)
    for (const auto& label_to_indices : operand_mapping)
      labels_for_any.insert(label_to_indices.first);

  // Filter repeated labels.
  for (const auto& operand_mapping : input_mappings)
    for (const auto& label_to_indices : operand_mapping)
      if (label_to_indices.second.size() > 1)
        labels_for_any.erase(label_to_indices.first);

  // Filter labels in output.
  for (const auto& label_to_indices : output_mapping)
    labels_for_any.erase(label_to_indices.first);

  llvm::DenseMap<int, Layout> input_layouts(num_inputs);

  // Derive operand sharding specs from output's sharding specs.
  for (size_t i = 0; i < num_inputs; ++i) {
    absl::flat_hash_map<char, std::vector<int>> labels_to_indices =
        input_mappings[i];
    std::pair<std::vector<ShardingSpec>, absl::flat_hash_map<std::string, int>>
        sharding_specs_and_dim_count = GetSpecsFromLabelsAndMap(
            labels_to_indices, output_label_to_sharding_spec);

    std::vector<ShardingSpec> sharding_specs =
        sharding_specs_and_dim_count.first;
    absl::flat_hash_map<std::string, int> dim_count =
        sharding_specs_and_dim_count.second;

    // Flip "unsharded" specs to "any" if they are present in the set.
    for (const auto& label_to_indices : labels_to_indices) {
      char label = label_to_indices.first;
      if (labels_for_any.contains(label)) {
        int index = label_to_indices.second[0];
        sharding_specs[index].set_sharding_spec(Layout::kAny);
      }
    }
    TF_ASSIGN_OR_RETURN(
        const auto& layout,
        VerifyOrFixLayout(std::make_pair(sharding_specs, dim_count), mesh));
    input_layouts[i] = layout;
  }

  return input_layouts;
}

// A few things we don't support, but could:
// * "xx->" or "xx->x": Trace-like operation where at least one input
//   dimension for x is sharded. If both are sharded, we could compute the
//   einsum on the diagonal machines in the mesh (with 0s on the off
//   diagonals) and then all-reduce the much smaller matrix.
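//
// Picks compatible sharded layouts for the einsum inputs (relaying out any
// input whose current layout differs), sets einsum_layout to the layout the
// local einsum result will have, and records in reduce_dims the sharding
// specs of the contracting labels, which ExpandOp all-reduces over.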
Status EinsumSPMDExpander::MaybeRelayoutInputs(
    const std::vector<Layout>& input_layouts, mlir::Operation* op,
    const Layout& output_layout, absl::flat_hash_set<std::string>& reduce_dims,
    Layout& einsum_layout, std::vector<mlir::Value>& new_inputs) {
  if (!mlir::isa<mlir::TF::EinsumOp>(op))
    return errors::InvalidArgument(
        "called einsum spmd expander but op is not Einsum.");

  mlir::TF::EinsumOp einsum = mlir::cast<mlir::TF::EinsumOp>(op);
  std::vector<absl::flat_hash_map<char, std::vector<int>>> input_mappings;
  absl::flat_hash_map<char, std::vector<int>> output_mapping;
  absl::flat_hash_set<char> contracting_labels;
  absl::flat_hash_set<char> all_labels;
  TF_RETURN_IF_ERROR(ExtractEquationRelations(einsum.equation().str(),
                                              contracting_labels,
                                              input_mappings, output_mapping));

  for (const auto& input_mapping : input_mappings)
    for (const auto& char_and_positions : input_mapping)
      all_labels.emplace(char_and_positions.first);

  // We will update this map throughout this function with the following
  // rules:
  // 1. The sharding of a label which is not in the map is unknown.
  // 2. Once the sharding of a label becomes known and is unsharded, we
  //    won't change that.
  TF_ASSIGN_OR_RETURN(auto input_label_to_sharding_spec,
                      GetLabelToShardingSpec(
                          /*replicate_incompatible_dimensions=*/false,
                          input_layouts, input_mappings));

  TF_ASSIGN_OR_RETURN(const auto output_label_to_sharding_spec,
                      GetLabelToShardingSpec(
                          /*replicate_incompatible_dimensions=*/false,
                          {output_layout}, {output_mapping}));

  for (const char label : all_labels) {
    if (input_label_to_sharding_spec.contains(label) &&
        output_label_to_sharding_spec.contains(label) &&
        !Equal(input_label_to_sharding_spec[label],
               output_label_to_sharding_spec.find(label)->second))
      return errors::InvalidArgument(
          "for label ", label, " input and output layouts are sharded on ",
          "non-equal dimensions ",
          input_label_to_sharding_spec[label].sharding_spec(), " and ",
          output_label_to_sharding_spec.find(label)->second.sharding_spec(),
          " respectively");
  }

  // First priority is to ensure that labels which occur at least twice on one
  // side never get sharded, as we cannot deal with that. This corresponds to
  // taking a trace on that input, which will require us to be unsharded.
  for (const auto& input_mapping : input_mappings)
    for (const auto& char_and_positions : input_mapping)
      if (char_and_positions.second.size() > 1)
        input_label_to_sharding_spec[char_and_positions.first]
            .set_sharding_spec(Layout::kUnshardedDim);

  absl::flat_hash_map<std::string, absl::flat_hash_set<char>>
      sharding_dim_to_non_contracting_labels;
  absl::flat_hash_map<std::string, absl::flat_hash_set<char>>
      sharding_dim_to_contracting_labels;
  for (const auto& label_and_spec : input_label_to_sharding_spec) {
    if (Layout::IsShardedSpec(label_and_spec.second)) {
      if (contracting_labels.contains(label_and_spec.first))
        sharding_dim_to_contracting_labels[label_and_spec.second
                                               .sharding_spec()]
            .insert(label_and_spec.first);
      else
        sharding_dim_to_non_contracting_labels[label_and_spec.second
                                                   .sharding_spec()]
            .insert(label_and_spec.first);
    }
  }

  // If a non-contracting label is sharded in the output, unsharded in the
  // input, and no other label is sharded on that mesh dimension, then shard
  // it.
  // This handles the *,x . x,* -> *,y case and also the case where batch
  // dimensions are sharded on the output but not the input.
  for (const char label : all_labels) {
    // Note that only sharded labels are in output_label_to_sharding_spec, so
    // there is no need to check that the spec is sharded.
    if (!contracting_labels.contains(label) &&
        output_label_to_sharding_spec.contains(label) &&
        !input_label_to_sharding_spec.contains(label)) {
      const ShardingSpec& sharding_spec =
          output_label_to_sharding_spec.find(label)->second;
      const std::string& string_spec = sharding_spec.sharding_spec();
      if (!sharding_dim_to_non_contracting_labels.contains(string_spec) &&
          !sharding_dim_to_contracting_labels.contains(string_spec)) {
        input_label_to_sharding_spec[label] = sharding_spec;
        sharding_dim_to_non_contracting_labels[string_spec].insert(label);
      }
    }
  }

  // Handle the case when two non-contracting dimensions have the same
  // sharding spec.
  // Note that the case of three non-contracting dimensions having the same
  // sharding spec is impossible: since there are at most two inputs, at least
  // one input would have two dimensions with the same sharding spec.
  // This handles the y,x . x,y -> *,y case.
  absl::flat_hash_set<std::string> dims_with_multiple_labels;
  for (const auto& spec_and_labels : sharding_dim_to_non_contracting_labels) {
    if (spec_and_labels.second.size() > 1) {
      assert(spec_and_labels.second.size() == 2);
      dims_with_multiple_labels.insert(spec_and_labels.first);
    }
  }
  for (const auto& dim : dims_with_multiple_labels) {
    // TODO(bfontain): Update this to pick default label to keep based on shape.
    char label_to_keep = 0xFF;
    // Note that the conditions evaluated in the loop below are mutually
    // exclusive, as no two labels in the output have the same sharding spec.
    // If no label is found in the output, we choose the lexicographically
    // least label to keep this stable with respect to ordering.
    for (const char label : sharding_dim_to_non_contracting_labels[dim]) {
      if (output_label_to_sharding_spec.contains(label) &&
          output_label_to_sharding_spec.find(label)->second.sharding_spec() ==
              dim) {
        label_to_keep = label;
        break;
      } else if (label < label_to_keep) {
        label_to_keep = label;
      }
    }
    for (const char label : sharding_dim_to_non_contracting_labels[dim])
      if (label != label_to_keep)
        input_label_to_sharding_spec[label].set_sharding_spec(
            Layout::kUnshardedDim);
    sharding_dim_to_non_contracting_labels[dim].clear();
    sharding_dim_to_non_contracting_labels[dim].insert(label_to_keep);
  }

  // Handle the case where a non-contracting and contracting dim have the same
  // sharding spec. For now we always unshard the contracting axis. Note that
  // this is safe.
  // This handles the case x,y . *,y -> x,y
  for (const auto& spec_and_labels : sharding_dim_to_contracting_labels) {
    if (!spec_and_labels.second.empty() &&
        !sharding_dim_to_non_contracting_labels[spec_and_labels.first]
             .empty()) {
      assert(spec_and_labels.second.size() == 1);
      assert(sharding_dim_to_non_contracting_labels[spec_and_labels.first]
                 .size() == 1);
      input_label_to_sharding_spec[*spec_and_labels.second.begin()]
          .set_sharding_spec(Layout::kUnshardedDim);
    }
  }

  // Relayout the inputs
  mlir::OpBuilder builder(op);
  new_inputs.resize(input_mappings.size());
  for (int i = 0; i < input_mappings.size(); ++i) {
    TF_ASSIGN_OR_RETURN(
        const Layout new_input_layout,
        VerifyOrFixLayout(GetSpecsFromLabelsAndMap(
                              input_mappings[i], input_label_to_sharding_spec),
                          output_layout.mesh()));

    TF_ASSIGN_OR_RETURN(
        new_inputs[i],
        EmitRelayout(op->getOperand(i), input_layouts[i], new_input_layout));
  }

  TF_ASSIGN_OR_RETURN(
      einsum_layout,
      VerifyOrFixLayout(GetSpecsFromLabelsAndMap(output_mapping,
                                                 input_label_to_sharding_spec),
                        output_layout.mesh()));

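  // Record the sharding specs of the contracting labels so that ExpandOp can
  // all-reduce the partial (local) einsum result over them.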
  for (const auto& contracting : contracting_labels)
    reduce_dims.emplace(
        input_label_to_sharding_spec[contracting].sharding_spec());

  return OkStatus();
}

}  // namespace dtensor
}  // namespace tensorflow