xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_sharding_metadata.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/shape_tree.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 
27 namespace xla {
28 
29 namespace {
30 
// AssignmentKind and kUnassignedDevice are used during tuple domain sharding
// propagation in order to distinguish among three cases:
// kUnassigned: no assignment has occurred
// kAssigned: at least an assignment has occurred
// kConflict: no assignment has occurred because of conflicting propagations,
// which occurs when multiple users of an instruction have different
// shardings.
enum class AssignmentKind { kUnassigned, kAssigned, kConflict };

// kUnassignedDevice can only be assigned to tuple leaf shardings to indicate
// absence of sharding information for that particular sub-sharding during
// sharding propagation. It is used to be able to express tuple shardings with
// partial information. At the end of the propagation the sharding of
// tuple-shaped instructions using kUnassignedDevice's is cleared.
// TODO(b/112883246): Centralized enum of reserved devices.
constexpr int64_t kUnassignedDevice = -2;
47 
// Pairs a kDomain exit user with the enter-domain operand it directly
// consumes, forming a "pass-through" (empty) domain link.
// user == nullptr marks the case where the operand is the computation's root
// instruction (see LocatePassThroughDomainLinks).
struct PassThrough {
  PassThrough(HloInstruction* user, HloInstruction* operand)
      : user(user), operand(operand) {}

  HloInstruction* user = nullptr;
  HloInstruction* operand = nullptr;
};
55 
SetSingleSharding(HloInstruction * instruction,const HloSharding & sharding)56 void SetSingleSharding(HloInstruction* instruction,
57                        const HloSharding& sharding) {
58   VLOG(4) << "  " << instruction->name() << " to " << sharding;
59   instruction->set_single_sharding(sharding);
60 }
61 
ShardingMatches(const HloSharding & sharding1,const HloSharding & sharding2)62 bool ShardingMatches(const HloSharding& sharding1,
63                      const HloSharding& sharding2) {
64   auto single_sharding1 = sharding1.ExtractSingleSharding();
65   if (single_sharding1) {
66     auto single_sharding2 = sharding2.ExtractSingleSharding();
67     if (single_sharding2) {
68       return *single_sharding1 == single_sharding2;
69     }
70   }
71   // Anything which is not unique across all elements, gets a full sharding
72   // compare.
73   return sharding1 == sharding2;
74 }
75 
76 // When we create domains, they are never "empty", where with empty we mean
77 // that a kDomain instruction has as operand another kDomain instruction of the
78 // same kind.
79 // But when the HLO optimizations are run, empty domains can be created.
80 // For example:
81 //
82 //  Domain(device=None, device=0) ->
83 //    Tuple(device=0) ->
84 //      GTE(device=0) ->
85 //        Domain(device=0, device=None)
86 //
87 // In that case the tuple simplifier could create something like:
88 //
89 //  Domain(device=None, device=0) -> Domain(device=0, device=None)
90 //
91 // Which is a so called empty domain.
92 // In the case above, crossing an empty domain which was transiting through
93 // device 0, requires the normalization phase to fixup the empty domain by
94 // adding back a Tuple+GTE pair with the proper device.
95 // One particular case where this can create problems is the result of the
96 // entry computation, where the GTE assignments are used by TF to tell the
97 // XLA where the results should be sent.
LocatePassThroughDomainLinks(const DomainMetadata::Domain & domain)98 std::vector<PassThrough> LocatePassThroughDomainLinks(
99     const DomainMetadata::Domain& domain) {
100   std::vector<PassThrough> pass_through;
101   for (HloInstruction* instruction : domain.enter_domains) {
102     CHECK(instruction->opcode() == HloOpcode::kDomain)
103         << "Instruction is not a kDomain: " << instruction->ToString();
104     for (HloInstruction* user : instruction->users()) {
105       if (user->opcode() == HloOpcode::kDomain &&
106           domain.exit_domains.contains(user)) {
107         pass_through.emplace_back(user, instruction);
108         VLOG(2) << "Found passthrough domain link:";
109         VLOG(2) << "  " << user->ToString();
110         VLOG(2) << "  " << instruction->ToString();
111       }
112     }
113     if (instruction == instruction->parent()->root_instruction()) {
114       pass_through.emplace_back(nullptr, instruction);
115       VLOG(2) << "Found passthrough domain link:";
116       VLOG(2) << "  <root>";
117       VLOG(2) << "  " << instruction->ToString();
118     }
119   }
120   return pass_through;
121 }
122 
FixupPassThroughDomainLinks(const DomainMetadata::Domain & domain,const HloSharding & sharding)123 Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain,
124                                    const HloSharding& sharding) {
125   for (auto& pass_through : LocatePassThroughDomainLinks(domain)) {
126     HloInstruction* tuple = pass_through.operand->parent()->AddInstruction(
127         HloInstruction::CreateTuple({pass_through.operand}));
128     HloInstruction* gte = pass_through.operand->parent()->AddInstruction(
129         HloInstruction::CreateGetTupleElement(pass_through.operand->shape(),
130                                               tuple, 0));
131     gte->set_sharding(sharding);
132     if (pass_through.user != nullptr) {
133       TF_RETURN_IF_ERROR(
134           pass_through.operand->ReplaceUseWith(pass_through.user, gte));
135     } else {
136       pass_through.operand->parent()->set_root_instruction(gte);
137     }
138   }
139   return OkStatus();
140 }
141 
142 // For tuple shardings if every element have the same sharsing then we want to
143 // treat them as single element sharsings to insert less domain separation as a
144 // domain can prevent some optimizations and we want to minimize that from
145 // happening.
CloneShardingForDomain(std::shared_ptr<const HloSharding> sharding)146 std::shared_ptr<const HloSharding> CloneShardingForDomain(
147     std::shared_ptr<const HloSharding> sharding) {
148   auto single_sharding = sharding->ExtractSingleSharding();
149   if (!single_sharding) {
150     return sharding;
151   }
152   return std::make_shared<const HloSharding>(*single_sharding);
153 }
154 
ApplyDomainSingleSharding(const DomainMetadata::Domain & domain,const HloSharding & sharding)155 Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain,
156                                  const HloSharding& sharding) {
157   VLOG(4) << "Applying " << sharding << " sharding";
158   for (HloInstruction* instruction : domain.instructions) {
159     // We only change instructions without sharding, since otherwise we might
160     // mess up with eventual HLO passes which has knowledge of it.
161     if (!instruction->has_sharding()) {
162       SetSingleSharding(instruction, sharding);
163     } else {
164       VLOG(4) << "  " << instruction->name() << " already has sharding "
165               << instruction->sharding();
166     }
167   }
168   return OkStatus();
169 }
170 
171 // Return the ShapeTree<HloSharding> of the user argument. The user argument
172 // is assumed to be a user of the instruction argument.
173 // If user is a tuple instruction, return the tuple subsharding corresponding to
174 // the operand matching the instruction argument, because that is the
175 // subsharding corresponding to instruction.
GetShardingTreeFromUser(const HloInstruction & instruction,const HloInstruction & user)176 StatusOr<ShapeTree<HloSharding>> GetShardingTreeFromUser(
177     const HloInstruction& instruction, const HloInstruction& user) {
178   if (user.opcode() == HloOpcode::kTuple) {
179     return user.sharding()
180         .GetSubSharding(user.shape(), {user.operand_index(&instruction)})
181         .AsShapeTree(instruction.shape());
182   }
183   return user.sharding().AsShapeTree(user.shape());
184 }
185 
186 // Assign rhs to lhs. If rhs is unassigned (assigned to kUnassignedDevice)
187 // then no assignment is made. Therefore kUnassignedDevice is never propagated.
188 // kConflict is returned if lhs is already assigned and rhs is assigned to a
189 // different device.
AssignLeafSharding(HloSharding * lhs,const HloSharding & rhs)190 StatusOr<AssignmentKind> AssignLeafSharding(HloSharding* lhs,
191                                             const HloSharding& rhs) {
192   TF_RET_CHECK(!lhs->IsTuple() && !rhs.IsTuple());
193   if (rhs.UsesDevice(kUnassignedDevice)) {
194     return AssignmentKind::kUnassigned;
195   }
196   if (lhs->UsesDevice(kUnassignedDevice)) {
197     *lhs = rhs;
198     return AssignmentKind::kAssigned;
199   }
200   return lhs->UniqueDevice() != rhs.UniqueDevice()
201              ? AssignmentKind::kConflict
202              : AssignmentKind::kUnassigned;
203 }
204 
// Assigns the whole rhs tree to lhs_tree, starting at lhs_it.
// In case of conflicting assignment AssignmentKind::kConflict is returned. In
// this case lhs_tree is partially assigned, up to the conflicting leaf. It is
// up to the caller to discard the partial assignment in case of conflict.
StatusOr<AssignmentKind> AssignTreeSharding(
    ShapeTree<HloSharding>* lhs_tree, ShapeTree<HloSharding>::iterator lhs_it,
    const ShapeTree<HloSharding>& rhs_tree) {
  AssignmentKind assigned = AssignmentKind::kUnassigned;
  auto rhs_it = rhs_tree.begin();
  // Advance both iterators in lockstep so each rhs node is matched with the
  // lhs node at the same relative position under lhs_it.
  for (; lhs_it != lhs_tree->end() && rhs_it != rhs_tree.end();
       ++lhs_it, ++rhs_it) {
    // TODO(b/112885211): Add ShapeTree::IsLeaf(const ShapeTreeIterator &it)
    if (rhs_tree.IsLeaf(rhs_it->first)) {
      TF_RET_CHECK(lhs_tree->IsLeaf(lhs_it->first));
      TF_ASSIGN_OR_RETURN(AssignmentKind sub_assigned,
                          AssignLeafSharding(&lhs_it->second, rhs_it->second));
      if (sub_assigned == AssignmentKind::kConflict) {
        // In case of conflict we return conflict to the caller. At this point
        // partial assignments to lhs_tree may have been made already. It is up
        // to the caller to discard the partial assignment in case of conflict.
        return AssignmentKind::kConflict;
      } else if (sub_assigned == AssignmentKind::kAssigned) {
        assigned = sub_assigned;
      }
    }
  }
  // rhs must have been fully consumed: it cannot extend past the lhs tree.
  TF_RET_CHECK(rhs_it == rhs_tree.end());
  return assigned;
}
234 
// Computes a sharding for `instruction` by combining the shardings of its
// users (honoring Tuple/GetTupleElement subsharding structure), falling back
// to `domain_sharding` when there are no users or when a user is one of the
// domain's exit kDomain instructions. Returns true if a sharding was assigned
// to `instruction`, false otherwise (no usable user sharding, or conflicting
// user shardings).
StatusOr<bool> ApplyShardingFromUsers(HloInstruction* instruction,
                                      const DomainMetadata::Domain& domain,
                                      const HloSharding& domain_sharding) {
  if (instruction->users().empty()) {
    // No sharding from users, use domain_sharding, after checking
    // compatibility.
    TF_RET_CHECK(instruction->shape().IsTuple() &&
                 ShapeUtil::GetLeafCount(instruction->shape()) ==
                     domain_sharding.tuple_elements().size());
    instruction->set_sharding(domain_sharding);
    return true;
  }
  AssignmentKind assigned = AssignmentKind::kUnassigned;
  // The sharding_tree leaves are initialized to kUnassignedDevice. Only Tuple
  // subshardings can result in a final sharding assignment containing
  // kUnassignedDevice leaves, in case some tuple indexes are not used, or are
  // used by users that don't have a sharding.
  // Non-tuple shardings are either assigned to a real sharding, or are not
  // assigned at all. As such they will never get assigned to kUnassignedDevice.
  // In any case, kUnassignedDevice is never propagated, from the implementation
  // of AssignLeafSharding.
  ShapeTree<HloSharding> sharding_tree(
      instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice));
  for (HloInstruction* user : instruction->users()) {
    if (user->opcode() == HloOpcode::kDomain &&
        domain.exit_domains.contains(user)) {
      // If a user is a domain and it is registered in the domain exits, then
      // the instruction sharding is taken directly from the domain, and no
      // further users need to be visited.
      instruction->set_sharding(domain_sharding);
      return true;
    }
    if (!user->has_sharding()) {
      // Users without a sharding contribute nothing; their tuple indexes
      // stay at kUnassignedDevice.
      continue;
    }
    AssignmentKind sub_assigned = AssignmentKind::kUnassigned;
    TF_ASSIGN_OR_RETURN(ShapeTree<HloSharding> user_sharding_tree,
                        GetShardingTreeFromUser(*instruction, *user));
    if (instruction->shape().IsTuple()) {
      // For tuple-shaped instructions collect individual tuple subshardings
      // from the uses, and then combine them into the tuple sharding.
      // If the user is a GTE its sharding concerns only the subtree of
      // sharding_tree at index user->tuple_index, otherwise the whole
      // sharding_tree is affected.
      ShapeTree<HloSharding>::iterator sharding_tree_begin =
          user->opcode() == HloOpcode::kGetTupleElement
              ? sharding_tree.find({user->tuple_index()})
              : sharding_tree.begin();
      TF_ASSIGN_OR_RETURN(
          sub_assigned, AssignTreeSharding(&sharding_tree, sharding_tree_begin,
                                           user_sharding_tree));
    } else {
      // Non-tuple shape: assign common users sharding.
      TF_RET_CHECK(user_sharding_tree.leaf_count() == 1)
          << "Expected non-tuple user sharding";
      TF_ASSIGN_OR_RETURN(
          sub_assigned,
          AssignTreeSharding(&sharding_tree, sharding_tree.begin(),
                             user_sharding_tree));
    }

    if (sub_assigned == AssignmentKind::kConflict) {
      // In case of conflict we don't assign any sharding.
      return false;
    } else if (sub_assigned == AssignmentKind::kAssigned) {
      assigned = sub_assigned;
    }
  }

  if (assigned == AssignmentKind::kAssigned) {
    // Materialize the combined result onto the instruction: a tuple sharding
    // for tuple shapes, the single leaf sharding otherwise.
    if (instruction->shape().IsTuple()) {
      instruction->set_sharding(HloSharding::Tuple(sharding_tree));
    } else {
      TF_RET_CHECK(sharding_tree.leaf_count() == 1);
      instruction->set_sharding(sharding_tree.leaf_begin()->second);
    }
    return true;
  }
  return false;
}
315 
316 // Tries to propagate the sharding information into the instructions that are
317 // part of the domain, in a reverse post order manner (users propagate to
318 // instruction).
ApplyDomainShardingPass(const DomainMetadata::Domain & domain,const HloSharding & domain_sharding)319 StatusOr<int64_t> ApplyDomainShardingPass(const DomainMetadata::Domain& domain,
320                                           const HloSharding& domain_sharding) {
321   int64_t assigned = 0;
322   // domain.instructions are ordered in a post-order manner. As we do
323   // user->operand propagation we process instructions in reverse order. In so
324   // doing we are guaranteed to process all users before their operands.
325   for (auto it = domain.instructions.rbegin(); it != domain.instructions.rend();
326        ++it) {
327     HloInstruction* instruction = *it;
328     if (instruction->has_sharding()) {
329       continue;
330     }
331     // Take the sharding from the users.
332     TF_ASSIGN_OR_RETURN(
333         bool instruction_assigned,
334         ApplyShardingFromUsers(instruction, domain, domain_sharding));
335     if (instruction_assigned) {
336       ++assigned;
337       VLOG(4) << "  " << instruction->name() << " to sharding "
338               << instruction->sharding();
339     }
340   }
341   return assigned;
342 }
343 
// Applies `sharding` to the instructions of `domain` that lack one. A
// sharding reducible to a single element takes the fast path; otherwise a
// user->operand propagation pass runs, followed by a cleanup of tuple
// shardings still holding kUnassignedDevice leaves.
Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
                           const HloSharding& sharding) {
  // None of the external normalizers handled the domain sharding, try to see
  // whether this is a single sharding first.
  auto single_sharding = sharding.ExtractSingleSharding();
  if (single_sharding) {
    // Shortcut the simple case. We have a unique sharding, so we call
    // the ApplyDomainSingleSharding() API which will apply array or tuple
    // shaped sharding to the domain instructions.
    return ApplyDomainSingleSharding(domain, *single_sharding);
  }
  VLOG(1) << "Assigning non-trivial sharding " << sharding;
  TF_RETURN_IF_ERROR(ApplyDomainShardingPass(domain, sharding).status());

  int64_t unassigned = 0;
  for (HloInstruction* instruction : domain.instructions) {
    if (!instruction->has_sharding()) {
      LOG(WARNING) << "Unassigned instruction: " << instruction->ToString();
      ++unassigned;
    } else {
      // Un-set sharding of tuples whose sub-shardings are assigned to
      // kUnassignedDevice. Indeed in case of doubt it is better to leave the
      // entire tuple unassigned, and let the device placer decide for it.
      // Do not clear the tuple sharding when the instruction is kParameter.
      // The sharding of the tuple might not be reconstructible if its users
      // are removed during DCE.
      if (instruction->sharding().UsesDevice(kUnassignedDevice) &&
          instruction->opcode() != HloOpcode::kParameter) {
        TF_RET_CHECK(instruction->shape().IsTuple())
            << "Only tuples can have kUnassignedDevice sub shardings";
        instruction->clear_sharding();
      }
    }
  }
  // Should we error out if unassigned > 0?
  return OkStatus();
}
381 
ExtractOriginalCommonSharding(absl::Span<HloInstruction * const> instructions)382 StatusOr<std::shared_ptr<const HloSharding>> ExtractOriginalCommonSharding(
383     absl::Span<HloInstruction* const> instructions) {
384   // If we are here, all the instructions being passed had the same sharding
385   // (or no sharding), by the means of the ShardingMatches() API.
386   // As such, no kDomain was inserted, and here we are asked to extract the
387   // original common sharding.
388   // All the instructions passed to this API are part of the same computation.
389   std::shared_ptr<const HloSharding> sharding;
390   for (HloInstruction* instruction : instructions) {
391     if (instruction->has_sharding()) {
392       if (sharding == nullptr) {
393         sharding = instruction->sharding_ptr();
394       } else {
395         TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding()))
396             << "Sharding " << *sharding << " does not match the one in "
397             << instruction->ToString();
398       }
399     }
400   }
401   if (sharding == nullptr) {
402     return std::shared_ptr<const HloSharding>();
403   }
404   VLOG(4) << "Extracted sharding is " << *sharding;
405   return CloneShardingForDomain(sharding);
406 }
407 
408 }  // namespace
409 
Clone() const410 std::unique_ptr<DomainMetadata> ShardingMetadata::Clone() const {
411   std::unique_ptr<HloSharding> sharding;
412   if (sharding_ != nullptr) {
413     sharding = std::make_unique<HloSharding>(*sharding_);
414   }
415   return std::make_unique<ShardingMetadata>(std::move(sharding));
416 }
417 
Matches(const DomainMetadata & other) const418 bool ShardingMetadata::Matches(const DomainMetadata& other) const {
419   const ShardingMetadata* other_ptr =
420       dynamic_cast<const ShardingMetadata*>(&other);
421   if (other_ptr == nullptr) {
422     // If other is not a ShardingMetadata, then it is clearly a no match.
423     return false;
424   }
425   if (sharding_ == nullptr) {
426     return other_ptr->sharding_ == nullptr;
427   }
428   return other_ptr->sharding_ != nullptr
429              ? ShardingMatches(*sharding_, *other_ptr->sharding_)
430              : false;
431 }
432 
ToString() const433 std::string ShardingMetadata::ToString() const {
434   return sharding_ != nullptr ? sharding_->ToString() : "{}";
435 }
436 
437 /*static*/ StatusOr<const ShardingMetadata*>
ToShardingMetadata(const DomainMetadata * metadata)438 ShardingMetadata::ToShardingMetadata(const DomainMetadata* metadata) {
439   if (metadata->Kind() != ShardingMetadata::KindName()) {
440     return Status(
441         tensorflow::error::INVALID_ARGUMENT,
442         "ShardingMetadata normalizer called with incorrect domain metadata");
443   }
444   return static_cast<const ShardingMetadata*>(metadata);
445 }
446 
NormalizeShardingDomain(const DomainMetadata::Domain & domain,const DomainMetadata * metadata)447 Status ShardingMetadata::NormalizeShardingDomain(
448     const DomainMetadata::Domain& domain, const DomainMetadata* metadata) {
449   if (metadata != nullptr) {
450     TF_ASSIGN_OR_RETURN(const auto& sharding_metadata,
451                         ToShardingMetadata(metadata));
452     const HloSharding* sharding = sharding_metadata->sharding();
453     if (sharding != nullptr) {
454       VLOG(4) << "Normalizing sharding to " << sharding->ToString() << ":";
455       TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding));
456       TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding));
457     }
458   } else {
459     TF_ASSIGN_OR_RETURN(std::shared_ptr<const HloSharding> sharding,
460                         ExtractOriginalCommonSharding(domain.instructions));
461     if (sharding != nullptr) {
462       VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString();
463       TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding));
464     } else {
465       VLOG(1) << "Unable to find common sharding";
466     }
467   }
468   return OkStatus();
469 }
470 
471 // Creates a kDomain instruction to be placed between instruction and operand.
472 // The kDomain instruction will be created only if the sharding differ between
473 // the instruction and the operand.
operator ()(HloInstruction * instruction,HloInstruction * root,HloInstruction * operand)474 HloInstruction* ShardingDomainCreator::operator()(HloInstruction* instruction,
475                                                   HloInstruction* root,
476                                                   HloInstruction* operand) {
477   auto instruction_sharding = instruction->sharding_ptr();
478   auto root_sharding = root->sharding_ptr();
479   // No need for domain if they both have no sharding.
480   if (instruction_sharding == nullptr && root_sharding == nullptr) {
481     return nullptr;
482   }
483   // No need for domain if they match.
484   if (instruction_sharding != nullptr && root_sharding != nullptr &&
485       ShardingMatches(*instruction_sharding, *root_sharding)) {
486     return nullptr;
487   }
488 
489   if (instruction_sharding != nullptr) {
490     instruction_sharding = CloneShardingForDomain(instruction_sharding);
491   }
492   if (root_sharding != nullptr) {
493     root_sharding = CloneShardingForDomain(root_sharding);
494   }
495 
496   auto it = domain_cse_map_.find({operand, instruction_sharding});
497   if (it != domain_cse_map_.end()) {
498     return it->second;
499   }
500 
501   VLOG(3) << "Creating domain:";
502   VLOG(3) << "  Instruction: " << instruction->name();
503   VLOG(3) << "  Operand: " << operand->name();
504   VLOG(3) << "    User side sharding: "
505           << (instruction_sharding != nullptr ? instruction_sharding->ToString()
506                                               : "None");
507   VLOG(3) << "    Operand side sharding: "
508           << (root_sharding != nullptr ? root_sharding->ToString() : "None");
509 
510   HloInstruction* domain =
511       operand->parent()->AddInstruction(HloInstruction::CreateDomain(
512           operand->shape(), operand,
513           std::make_unique<ShardingMetadata>(root_sharding),
514           std::make_unique<ShardingMetadata>(instruction_sharding)));
515   domain_cse_map_.emplace(DomainCseMapKey{operand, instruction_sharding},
516                           domain);
517   return domain;
518 }
519 
operator ==(const ShardingDomainCreator::DomainCseMapKey & other) const520 bool ShardingDomainCreator::DomainCseMapKey::operator==(
521     const ShardingDomainCreator::DomainCseMapKey& other) const {
522   if (instruction != other.instruction) {
523     return false;
524   }
525   if (sharding == nullptr && other.sharding == nullptr) {
526     return true;
527   }
528   if (sharding == nullptr || other.sharding == nullptr) {
529     return false;
530   }
531   return *sharding == *other.sharding;
532 }
533 
534 }  // namespace xla
535