/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"

#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

namespace {
// AssignmentKind and kUnassignedDevice are used during tuple domain sharding
// propagation in order to distinguish among three cases:
// kUnassigned: no assignment has occurred
// kAssigned: at least one assignment has occurred
// kConflict: no assignment has occurred because of conflicting propagations,
// which occurs when multiple users of an instruction have different
// shardings.
enum class AssignmentKind { kUnassigned, kAssigned, kConflict };

// kUnassignedDevice can only be assigned to tuple leaf shardings to indicate
// absence of sharding information for that particular sub-sharding during
// sharding propagation. It is used to be able to express tuple shardings with
// partial information. At the end of the propagation the sharding of
// tuple-shaped instructions using kUnassignedDevice is cleared.
// TODO(b/112883246): Centralized enum of reserved devices.
constexpr int64_t kUnassignedDevice = -2;

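// Records an empty-domain (pass-through) link between a kDomain `user` and
// its kDomain `operand`. A null `user` means the operand is the computation
// root.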
struct PassThrough {
  PassThrough(HloInstruction* user, HloInstruction* operand)
      : user(user), operand(operand) {}

  HloInstruction* user = nullptr;
  HloInstruction* operand = nullptr;
};

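// Applies the given non-tuple `sharding` to `instruction`, replicating it
// across all tuple leaves if the instruction is tuple-shaped.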
void SetSingleSharding(HloInstruction* instruction,
                       const HloSharding& sharding) {
  VLOG(4) << " " << instruction->name() << " to " << sharding;
  instruction->set_single_sharding(sharding);
}

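// Compares two shardings for domain purposes: tuple shardings whose elements
// are all identical are treated as the equivalent single sharding; otherwise
// the shardings are compared in full.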
bool ShardingMatches(const HloSharding& sharding1,
                     const HloSharding& sharding2) {
  auto single_sharding1 = sharding1.ExtractSingleSharding();
  if (single_sharding1) {
    auto single_sharding2 = sharding2.ExtractSingleSharding();
    if (single_sharding2) {
      return *single_sharding1 == single_sharding2;
    }
  }
  // Anything that is not unique across all elements gets a full sharding
  // comparison.
  return sharding1 == sharding2;
}

// When we create domains, they are never "empty", where by "empty" we mean
// that a kDomain instruction has as operand another kDomain instruction of
// the same kind.
// But when the HLO optimizations are run, empty domains can be created.
// For example:
//
//   Domain(device=None, device=0) ->
//     Tuple(device=0) ->
//       GTE(device=0) ->
//         Domain(device=0, device=None)
//
// In that case the tuple simplifier could create something like:
//
//   Domain(device=None, device=0) -> Domain(device=0, device=None)
//
// which is a so-called empty domain.
// In the case above, crossing an empty domain which was transiting through
// device 0 requires the normalization phase to fix up the empty domain by
// adding back a Tuple+GTE pair with the proper device.
// One particular case where this can create problems is the result of the
// entry computation, where the GTE assignments are used by TF to tell XLA
// where the results should be sent.
std::vector<PassThrough> LocatePassThroughDomainLinks(
    const DomainMetadata::Domain& domain) {
  std::vector<PassThrough> pass_through;
  for (HloInstruction* instruction : domain.enter_domains) {
    CHECK(instruction->opcode() == HloOpcode::kDomain)
        << "Instruction is not a kDomain: " << instruction->ToString();
    for (HloInstruction* user : instruction->users()) {
      if (user->opcode() == HloOpcode::kDomain &&
          domain.exit_domains.contains(user)) {
        pass_through.emplace_back(user, instruction);
        VLOG(2) << "Found passthrough domain link:";
        VLOG(2) << " " << user->ToString();
        VLOG(2) << " " << instruction->ToString();
      }
    }
    if (instruction == instruction->parent()->root_instruction()) {
      pass_through.emplace_back(nullptr, instruction);
      VLOG(2) << "Found passthrough domain link:";
      VLOG(2) << " <root>";
      VLOG(2) << " " << instruction->ToString();
    }
  }
  return pass_through;
}

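// Fixes up pass-through (empty) domain links by materializing a Tuple+GTE
// pair between the two kDomain instructions and assigning `sharding` to the
// GTE, so that the device assignment survives the empty domain.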
Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain,
                                   const HloSharding& sharding) {
  for (auto& pass_through : LocatePassThroughDomainLinks(domain)) {
    HloInstruction* tuple = pass_through.operand->parent()->AddInstruction(
        HloInstruction::CreateTuple({pass_through.operand}));
    HloInstruction* gte = pass_through.operand->parent()->AddInstruction(
        HloInstruction::CreateGetTupleElement(pass_through.operand->shape(),
                                              tuple, 0));
    gte->set_sharding(sharding);
    if (pass_through.user != nullptr) {
      TF_RETURN_IF_ERROR(
          pass_through.operand->ReplaceUseWith(pass_through.user, gte));
    } else {
      pass_through.operand->parent()->set_root_instruction(gte);
    }
  }
  return OkStatus();
}

// For tuple shardings, if every element has the same sharding then we want to
// treat them as single-element shardings, to insert fewer domains, since a
// domain can prevent some optimizations and we want to minimize that from
// happening.
std::shared_ptr<const HloSharding> CloneShardingForDomain(
    std::shared_ptr<const HloSharding> sharding) {
  auto single_sharding = sharding->ExtractSingleSharding();
  if (!single_sharding) {
    return sharding;
  }
  return std::make_shared<const HloSharding>(*single_sharding);
}

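// Assigns the single (non-tuple) `sharding` to every instruction in the
// domain that does not already have a sharding.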
Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain,
                                 const HloSharding& sharding) {
  VLOG(4) << "Applying " << sharding << " sharding";
  for (HloInstruction* instruction : domain.instructions) {
    // We only change instructions without sharding, since otherwise we might
    // interfere with later HLO passes which have knowledge of it.
    if (!instruction->has_sharding()) {
      SetSingleSharding(instruction, sharding);
    } else {
      VLOG(4) << " " << instruction->name() << " already has sharding "
              << instruction->sharding();
    }
  }
  return OkStatus();
}

// Return the ShapeTree<HloSharding> of the user argument. The user argument
// is assumed to be a user of the instruction argument.
// If user is a tuple instruction, return the tuple subsharding corresponding
// to the operand matching the instruction argument, because that is the
// subsharding corresponding to instruction.
StatusOr<ShapeTree<HloSharding>> GetShardingTreeFromUser(
    const HloInstruction& instruction, const HloInstruction& user) {
  if (user.opcode() == HloOpcode::kTuple) {
    return user.sharding()
        .GetSubSharding(user.shape(), {user.operand_index(&instruction)})
        .AsShapeTree(instruction.shape());
  }
  return user.sharding().AsShapeTree(user.shape());
}

// Assign rhs to lhs. If rhs is unassigned (assigned to kUnassignedDevice)
// then no assignment is made. Therefore kUnassignedDevice is never propagated.
// kConflict is returned if lhs is already assigned and rhs is assigned to a
// different device.
StatusOr<AssignmentKind> AssignLeafSharding(HloSharding* lhs,
                                            const HloSharding& rhs) {
  TF_RET_CHECK(!lhs->IsTuple() && !rhs.IsTuple());
  if (rhs.UsesDevice(kUnassignedDevice)) {
    return AssignmentKind::kUnassigned;
  }
  if (lhs->UsesDevice(kUnassignedDevice)) {
    *lhs = rhs;
    return AssignmentKind::kAssigned;
  }
  return lhs->UniqueDevice() != rhs.UniqueDevice()
             ? AssignmentKind::kConflict
             : AssignmentKind::kUnassigned;
}

// Assigns the whole rhs tree to lhs_tree, starting at lhs_it.
// In case of conflicting assignment AssignmentKind::kConflict is returned. In
// this case lhs_tree is partially assigned, up to the conflicting leaf. It is
// up to the caller to discard the partial assignment in case of conflict.
StatusOr<AssignmentKind> AssignTreeSharding(
    ShapeTree<HloSharding>* lhs_tree, ShapeTree<HloSharding>::iterator lhs_it,
    const ShapeTree<HloSharding>& rhs_tree) {
  AssignmentKind assigned = AssignmentKind::kUnassigned;
  auto rhs_it = rhs_tree.begin();
  for (; lhs_it != lhs_tree->end() && rhs_it != rhs_tree.end();
       ++lhs_it, ++rhs_it) {
    // TODO(b/112885211): Add ShapeTree::IsLeaf(const ShapeTreeIterator &it)
    if (rhs_tree.IsLeaf(rhs_it->first)) {
      TF_RET_CHECK(lhs_tree->IsLeaf(lhs_it->first));
      TF_ASSIGN_OR_RETURN(AssignmentKind sub_assigned,
                          AssignLeafSharding(&lhs_it->second, rhs_it->second));
      if (sub_assigned == AssignmentKind::kConflict) {
        // In case of conflict we return conflict to the caller. At this point
        // partial assignments to lhs_tree may have been made already. It is up
        // to the caller to discard the partial assignment in case of conflict.
        return AssignmentKind::kConflict;
      } else if (sub_assigned == AssignmentKind::kAssigned) {
        assigned = sub_assigned;
      }
    }
  }
  TF_RET_CHECK(rhs_it == rhs_tree.end());
  return assigned;
}

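// Tries to assign a sharding to `instruction` by combining the shardings of
// its users, falling back to `domain_sharding` when the users provide none.
// Returns true if a sharding was assigned.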
StatusOr<bool> ApplyShardingFromUsers(HloInstruction* instruction,
                                      const DomainMetadata::Domain& domain,
                                      const HloSharding& domain_sharding) {
  if (instruction->users().empty()) {
    // No sharding from users, use domain_sharding, after checking
    // compatibility.
    TF_RET_CHECK(instruction->shape().IsTuple() &&
                 ShapeUtil::GetLeafCount(instruction->shape()) ==
                     domain_sharding.tuple_elements().size());
    instruction->set_sharding(domain_sharding);
    return true;
  }
  AssignmentKind assigned = AssignmentKind::kUnassigned;
  // The sharding_tree leaves are initialized to kUnassignedDevice. Only Tuple
  // subshardings can result in a final sharding assignment containing
  // kUnassignedDevice leaves, in case some tuple indexes are not used, or are
  // used by users that don't have a sharding.
  // Non-tuple shardings are either assigned to a real sharding, or are not
  // assigned at all. As such they will never get assigned to
  // kUnassignedDevice. In any case, kUnassignedDevice is never propagated,
  // from the implementation of AssignLeafSharding.
  ShapeTree<HloSharding> sharding_tree(
      instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice));
  for (HloInstruction* user : instruction->users()) {
    if (user->opcode() == HloOpcode::kDomain &&
        domain.exit_domains.contains(user)) {
      // If a user is a domain and it is registered in the domain exits, then
      // the instruction sharding is taken directly from the domain, and no
      // further users need to be visited.
      instruction->set_sharding(domain_sharding);
      return true;
    }
    if (!user->has_sharding()) {
      continue;
    }
    AssignmentKind sub_assigned = AssignmentKind::kUnassigned;
    TF_ASSIGN_OR_RETURN(ShapeTree<HloSharding> user_sharding_tree,
                        GetShardingTreeFromUser(*instruction, *user));
    if (instruction->shape().IsTuple()) {
      // For tuple-shaped instructions collect individual tuple subshardings
      // from the uses, and then combine them into the tuple sharding.
      // If the user is a GTE its sharding concerns only the subtree of
      // sharding_tree at index user->tuple_index, otherwise the whole
      // sharding_tree is affected.
      ShapeTree<HloSharding>::iterator sharding_tree_begin =
          user->opcode() == HloOpcode::kGetTupleElement
              ? sharding_tree.find({user->tuple_index()})
              : sharding_tree.begin();
      TF_ASSIGN_OR_RETURN(
          sub_assigned, AssignTreeSharding(&sharding_tree, sharding_tree_begin,
                                           user_sharding_tree));
    } else {
      // Non-tuple shape: assign common users sharding.
      TF_RET_CHECK(user_sharding_tree.leaf_count() == 1)
          << "Expected non-tuple user sharding";
      TF_ASSIGN_OR_RETURN(
          sub_assigned,
          AssignTreeSharding(&sharding_tree, sharding_tree.begin(),
                             user_sharding_tree));
    }

    if (sub_assigned == AssignmentKind::kConflict) {
      // In case of conflict we don't assign any sharding.
      return false;
    } else if (sub_assigned == AssignmentKind::kAssigned) {
      assigned = sub_assigned;
    }
  }

  if (assigned == AssignmentKind::kAssigned) {
    if (instruction->shape().IsTuple()) {
      instruction->set_sharding(HloSharding::Tuple(sharding_tree));
    } else {
      TF_RET_CHECK(sharding_tree.leaf_count() == 1);
      instruction->set_sharding(sharding_tree.leaf_begin()->second);
    }
    return true;
  }
  return false;
}

// Tries to propagate the sharding information into the instructions that are
// part of the domain, in a reverse post order manner (users propagate to
// instruction).
StatusOr<int64_t> ApplyDomainShardingPass(const DomainMetadata::Domain& domain,
                                          const HloSharding& domain_sharding) {
  int64_t assigned = 0;
  // domain.instructions are ordered in a post-order manner. As we do
  // user->operand propagation we process instructions in reverse order. In so
  // doing we are guaranteed to process all users before their operands.
  for (auto it = domain.instructions.rbegin();
       it != domain.instructions.rend(); ++it) {
    HloInstruction* instruction = *it;
    if (instruction->has_sharding()) {
      continue;
    }
    // Take the sharding from the users.
    TF_ASSIGN_OR_RETURN(
        bool instruction_assigned,
        ApplyShardingFromUsers(instruction, domain, domain_sharding));
    if (instruction_assigned) {
      ++assigned;
      VLOG(4) << " " << instruction->name() << " to sharding "
              << instruction->sharding();
    }
  }
  return assigned;
}

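// Applies `sharding` to all the instructions of `domain`: a sharding that is
// unique across all elements takes the ApplyDomainSingleSharding() shortcut,
// while a tuple sharding with distinct elements is propagated user-to-operand
// by ApplyDomainShardingPass().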
Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
                           const HloSharding& sharding) {
  // None of the external normalizers handled the domain sharding, try to see
  // whether this is a single sharding first.
  auto single_sharding = sharding.ExtractSingleSharding();
  if (single_sharding) {
    // Shortcut the simple case. We have a unique sharding, so we call
    // the ApplyDomainSingleSharding() API which will apply array or tuple
    // shaped sharding to the domain instructions.
    return ApplyDomainSingleSharding(domain, *single_sharding);
  }
  VLOG(1) << "Assigning non-trivial sharding " << sharding;
  TF_RETURN_IF_ERROR(ApplyDomainShardingPass(domain, sharding).status());

  int64_t unassigned = 0;
  for (HloInstruction* instruction : domain.instructions) {
    if (!instruction->has_sharding()) {
      LOG(WARNING) << "Unassigned instruction: " << instruction->ToString();
      ++unassigned;
    } else {
      // Un-set sharding of tuples whose sub-shardings are assigned to
      // kUnassignedDevice. Indeed in case of doubt it is better to leave the
      // entire tuple unassigned, and let the device placer decide for it.
      // Do not clear the tuple sharding when the instruction is kParameter.
      // The sharding of the tuple might not be reconstructible if its users
      // are removed during DCE.
      if (instruction->sharding().UsesDevice(kUnassignedDevice) &&
          instruction->opcode() != HloOpcode::kParameter) {
        TF_RET_CHECK(instruction->shape().IsTuple())
            << "Only tuples can have kUnassignedDevice sub shardings";
        instruction->clear_sharding();
      }
    }
  }
  // Should we error out if unassigned > 0?
  return OkStatus();
}

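// Returns the unique sharding shared by all of `instructions`, or a null
// pointer if none of them carries a sharding. It is an error for two of the
// instructions to carry non-matching shardings.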
StatusOr<std::shared_ptr<const HloSharding>> ExtractOriginalCommonSharding(
    absl::Span<HloInstruction* const> instructions) {
  // If we are here, all the instructions being passed had the same sharding
  // (or no sharding), by the means of the ShardingMatches() API.
  // As such, no kDomain was inserted, and here we are asked to extract the
  // original common sharding.
  // All the instructions passed to this API are part of the same computation.
  std::shared_ptr<const HloSharding> sharding;
  for (HloInstruction* instruction : instructions) {
    if (instruction->has_sharding()) {
      if (sharding == nullptr) {
        sharding = instruction->sharding_ptr();
      } else {
        TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding()))
            << "Sharding " << *sharding << " does not match the one in "
            << instruction->ToString();
      }
    }
  }
  if (sharding == nullptr) {
    return std::shared_ptr<const HloSharding>();
  }
  VLOG(4) << "Extracted sharding is " << *sharding;
  return CloneShardingForDomain(sharding);
}

}  // namespace

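// Returns a deep copy of this metadata, duplicating the owned sharding.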
std::unique_ptr<DomainMetadata> ShardingMetadata::Clone() const {
  std::unique_ptr<HloSharding> sharding;
  if (sharding_ != nullptr) {
    sharding = std::make_unique<HloSharding>(*sharding_);
  }
  return std::make_unique<ShardingMetadata>(std::move(sharding));
}

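// Two ShardingMetadata objects match when their shardings match, or when
// both are unset; any other DomainMetadata kind never matches.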
bool ShardingMetadata::Matches(const DomainMetadata& other) const {
  const ShardingMetadata* other_ptr =
      dynamic_cast<const ShardingMetadata*>(&other);
  if (other_ptr == nullptr) {
    // If other is not a ShardingMetadata, then it is clearly not a match.
    return false;
  }
  if (sharding_ == nullptr) {
    return other_ptr->sharding_ == nullptr;
  }
  return other_ptr->sharding_ != nullptr
             ? ShardingMatches(*sharding_, *other_ptr->sharding_)
             : false;
}

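// Returns the string representation of the wrapped sharding, or "{}" if no
// sharding is set.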
std::string ShardingMetadata::ToString() const {
  return sharding_ != nullptr ? sharding_->ToString() : "{}";
}

/*static*/ StatusOr<const ShardingMetadata*>
ShardingMetadata::ToShardingMetadata(const DomainMetadata* metadata) {
  if (metadata->Kind() != ShardingMetadata::KindName()) {
    return Status(
        tensorflow::error::INVALID_ARGUMENT,
        "ShardingMetadata normalizer called with incorrect domain metadata");
  }
  return static_cast<const ShardingMetadata*>(metadata);
}

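// If `metadata` carries a sharding, applies it to the domain instructions
// and fixes up pass-through domain links; otherwise tries to recover and
// apply the common sharding of the domain's instructions.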
Status ShardingMetadata::NormalizeShardingDomain(
    const DomainMetadata::Domain& domain, const DomainMetadata* metadata) {
  if (metadata != nullptr) {
    TF_ASSIGN_OR_RETURN(const auto& sharding_metadata,
                        ToShardingMetadata(metadata));
    const HloSharding* sharding = sharding_metadata->sharding();
    if (sharding != nullptr) {
      VLOG(4) << "Normalizing sharding to " << sharding->ToString() << ":";
      TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding));
      TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding));
    }
  } else {
    TF_ASSIGN_OR_RETURN(std::shared_ptr<const HloSharding> sharding,
                        ExtractOriginalCommonSharding(domain.instructions));
    if (sharding != nullptr) {
      VLOG(4) << "Normalizing sharding-less domain to "
              << sharding->ToString();
      TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding));
    } else {
      VLOG(1) << "Unable to find common sharding";
    }
  }
  return OkStatus();
}

// Creates a kDomain instruction to be placed between instruction and operand.
// The kDomain instruction will be created only if the shardings differ
// between the instruction and the operand.
HloInstruction* ShardingDomainCreator::operator()(HloInstruction* instruction,
                                                  HloInstruction* root,
                                                  HloInstruction* operand) {
  auto instruction_sharding = instruction->sharding_ptr();
  auto root_sharding = root->sharding_ptr();
  // No need for a domain if they both have no sharding.
  if (instruction_sharding == nullptr && root_sharding == nullptr) {
    return nullptr;
  }
  // No need for a domain if they match.
  if (instruction_sharding != nullptr && root_sharding != nullptr &&
      ShardingMatches(*instruction_sharding, *root_sharding)) {
    return nullptr;
  }

  if (instruction_sharding != nullptr) {
    instruction_sharding = CloneShardingForDomain(instruction_sharding);
  }
  if (root_sharding != nullptr) {
    root_sharding = CloneShardingForDomain(root_sharding);
  }

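  // Reuse an equivalent domain instruction previously created for the same
  // (operand, sharding) pair, instead of emitting a duplicate.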
  auto it = domain_cse_map_.find({operand, instruction_sharding});
  if (it != domain_cse_map_.end()) {
    return it->second;
  }

  VLOG(3) << "Creating domain:";
  VLOG(3) << " Instruction: " << instruction->name();
  VLOG(3) << " Operand: " << operand->name();
  VLOG(3) << " User side sharding: "
          << (instruction_sharding != nullptr ? instruction_sharding->ToString()
                                              : "None");
  VLOG(3) << " Operand side sharding: "
          << (root_sharding != nullptr ? root_sharding->ToString() : "None");

  HloInstruction* domain =
      operand->parent()->AddInstruction(HloInstruction::CreateDomain(
          operand->shape(), operand,
          std::make_unique<ShardingMetadata>(root_sharding),
          std::make_unique<ShardingMetadata>(instruction_sharding)));
  domain_cse_map_.emplace(DomainCseMapKey{operand, instruction_sharding},
                          domain);
  return domain;
}

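// Keys compare equal when they reference the same operand instruction and
// their shardings are either both unset or compare equal by value.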
bool ShardingDomainCreator::DomainCseMapKey::operator==(
    const ShardingDomainCreator::DomainCseMapKey& other) const {
  if (instruction != other.instruction) {
    return false;
  }
  if (sharding == nullptr && other.sharding == nullptr) {
    return true;
  }
  if (sharding == nullptr || other.sharding == nullptr) {
    return false;
  }
  return *sharding == *other.sharding;
}

}  // namespace xla