/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_propagation.h"

#include "absl/algorithm/container.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

BFloat16Propagation::BFloat16Propagation(
    const BFloat16Support* bfloat16_support)
    : bfloat16_support_(bfloat16_support) {}

void BFloat16Propagation::DetermineFusionComputationPrecision(
    HloInstruction* fusion) {
  CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
  if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) {
    return;
  }

  // We are depending on the fusion node itself having already been analyzed
  // for whether it can output BF16 and this has been adjusted in the output
  // shape, and now we're looking to update the interior of the fusion node to
  // match the new output shape, as well as recursively process the whole
  // fusion node even if the output shape was not modified.
  auto root = fusion->fused_instructions_computation()->root_instruction();

  // Adjust root's element types according to the fusion's output shape.
  ShapeUtil::ForEachSubshape(
      root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
        if (subshape.element_type() != F32) {
          return;
        }
        if (OutputTypeAfterChange(fusion, index) == BF16) {
          AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
          VLOG(2) << "Fused root " << root->ToString() << " at shape index "
                  << index << " changed to BF16 precision for fusion "
                  << fusion->ToString();
        }
      });

  // Propagate BF16 in the fusion computation.
  auto insts =
      fusion->fused_instructions_computation()->MakeInstructionPostOrder();
  for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
    DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
  }
  computations_visited_in_backward_pass_.insert(
      fusion->fused_instructions_computation());

  RevertIfFusionInternalBF16Changes(fusion);
}

void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
    HloInstruction* fusion) {
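  // Returns whether inst has any pending F32 -> BF16 change recorded in
  // changes_to_bf16_.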
  auto has_changes = [this](HloInstruction* inst) {
    auto it = changes_to_bf16_.find(inst);
    return it != changes_to_bf16_.end() && !it->second.empty();
  };

  auto root = fusion->fused_instructions_computation()->root_instruction();
  absl::flat_hash_set<const HloValue*> changed_root_buffers;

  auto root_changes_it = changes_to_bf16_.find(root);
  if (root_changes_it != changes_to_bf16_.end()) {
    for (const auto& entry : root_changes_it->second) {
      for (const HloValue* value :
           dataflow_->GetValueSet(root, entry.second).values()) {
        changed_root_buffers.insert(value);
      }
    }
  }

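  // Returns whether any F32 subshape of inst carries a dataflow value that is
  // also in changed_root_buffers, i.e., the instruction aliases a root buffer
  // that changed to BF16.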
  auto aliases_changed_root_buffer = [this, &changed_root_buffers](
                                         const HloInstruction* inst) {
    bool aliasing = false;
    ShapeUtil::ForEachSubshape(inst->shape(), [&](const Shape& subshape,
                                                  const ShapeIndex& index) {
      if (aliasing) {
        // Skip if aliasing is already found.
        return;
      }
      // Only F32 buffers are considered for changing to BF16 in this
      // pass.
      if (subshape.element_type() != F32) {
        return;
      }

      aliasing = absl::c_any_of(dataflow_->GetValueSet(inst, index).values(),
                                IsValueIn(changed_root_buffers));
    });
    return aliasing;
  };

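  // Walk the fused instructions and revert BF16 changes that neither
  // originate from a changed operand nor feed a changed root buffer; such
  // changes would only introduce convert/convert-back pairs inside the
  // fusion.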
  for (auto inst :
       fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
    if (inst->opcode() == HloOpcode::kParameter) {
      continue;
    }
    if (aliases_changed_root_buffer(inst)) {
      continue;
    }
    if (inst->opcode() == HloOpcode::kFusion) {
      bool parameter_reverted = false;
      for (int64_t i = 0; i < inst->operand_count(); ++i) {
        if (has_changes(inst->mutable_operand(i))) {
          // Changes on the operand have not been reverted.
          continue;
        }
        auto* fused_parameter = inst->fused_parameter(i);
        if (has_changes(fused_parameter)) {
          changes_to_bf16_.erase(fused_parameter);
          parameter_reverted = true;
        }
      }
      if (parameter_reverted) {
        RevertIfFusionInternalBF16Changes(inst);
      }
    }
    if (!has_changes(inst)) {
      continue;
    }
    bool revert_changes = true;
    for (auto operand : inst->operands()) {
      if (has_changes(operand)) {
        revert_changes = false;
        break;
      }
    }
    if (revert_changes) {
      changes_to_bf16_.erase(inst);
    }
  }
}

void BFloat16Propagation::DetermineWhileComputationsPrecision(
    HloInstruction* while_hlo) {
  CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);

  // We are depending on the while node itself having already been analyzed for
  // whether it can output BF16 and this has been adjusted in the output shape,
  // and now we're looking to update the body and condition computations to
  // match the new output shape, as well as recursively process the whole while
  // node even if the output shape was not modified.
  HloComputation* body = while_hlo->while_body();
  auto body_root = body->root_instruction();
  HloComputation* condition = while_hlo->while_condition();

  ShapeUtil::ForEachSubshape(
      body_root->shape(), [this, while_hlo, body_root](
                              const Shape& subshape, const ShapeIndex& index) {
        if (subshape.element_type() != F32) {
          return;
        }
        if (OutputTypeAfterChange(while_hlo, index) == BF16) {
          AddToOrRemoveFromBF16ChangeSet(body_root, index, BF16);
          VLOG(2) << "While body root " << body_root->ToString()
                  << " at shape index " << index
                  << " changed to BF16 precision for while "
                  << while_hlo->ToString();
        }
      });

  auto body_insts = body->MakeInstructionPostOrder();
  for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend();
       ++inst_it) {
    DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
  }
  computations_visited_in_backward_pass_.insert(body);

  auto condition_insts = condition->MakeInstructionPostOrder();
  for (auto inst_it = condition_insts.rbegin();
       inst_it != condition_insts.rend(); ++inst_it) {
    DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
  }
  computations_visited_in_backward_pass_.insert(condition);
}

void BFloat16Propagation::DetermineConditionalComputationsPrecision(
    HloInstruction* cond) {
  CHECK_EQ(cond->opcode(), HloOpcode::kConditional);
  for (int64_t i = 0; i < cond->branch_count(); ++i) {
    auto branch = cond->branch_computation(i);
    auto root = branch->root_instruction();
    ShapeUtil::ForEachSubshape(
        root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
          if (subshape.element_type() != F32) {
            return;
          }
          if (OutputTypeAfterChange(cond, index) == BF16) {
            AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
            VLOG(2) << "Conditional branch " << i << " root "
                    << root->ToString() << " at shape index " << index
                    << " changed to BF16 precision for conditional "
                    << cond->ToString();
          }
        });
    auto insts = branch->MakeInstructionPostOrder();
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
    }
    computations_visited_in_backward_pass_.insert(branch);
  }
}

bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
                                              const ShapeIndex& index) const {
  // If the subshape isn't floating point then none of the users will be BF16.
  const Shape& subshape = ShapeUtil::GetSubshape(hlo.shape(), index);
  if (subshape.element_type() != BF16 && subshape.element_type() != F32) {
    return false;
  }

  auto& value_set = dataflow_->GetValueSet(&hlo, index);
  for (const HloValue* value : value_set.values()) {
    if (ContainsKey(values_that_must_be_kept_as_f32_, value)) {
      return false;
    }
    // We use the original type for the value because we are going to examine
    // the uses of it, instead of the value itself. If ValueTypeAfterChange()
    // were used, it would cause problems when there are aliasing buffers,
    // i.e., ResolveInconsistencyOfAliasingBuffers() would fail to revert the
    // tentative change to BF16 even if the uses require F32.
    if (value->shape().element_type() == BF16) {
      continue;
    }
    for (const HloUse& use : value->GetUses()) {
      if (!ContainsKey(instructions_visited_in_backward_pass_,
                       use.instruction)) {
        // We don't know yet whether use.instruction will consume BF16 since it
        // hasn't been visited. Although we visit instructions in reverse
        // topological order, this is still possible because there may be
        // unvisited instructions that alias the same buffer. In this case, we
        // aggressively skip this use, and if this causes inconsistency (e.g.,
        // one use is in BF16 but another use is in F32), it will be resolved
        // at the end of the BFloat16Propagation pass.
        continue;
      }
      if (use.instruction->HasSideEffectNoRecurse()) {
        // Keep side-effecting instruction's operands unchanged.
        return false;
      }
      // Any visited user that can accept BF16 has already been updated if
      // necessary, e.g., the output has been changed to BF16 if it propagates
      // precision, or a called computation's parameters have been changed to
      // BF16 for fusions or whiles.
      if (use.instruction->opcode() == HloOpcode::kFusion) {
        auto* fused_parameter =
            use.instruction->fused_parameter(use.operand_number);
        if (OutputTypeAfterChange(fused_parameter, use.operand_index) !=
            BF16) {
          return false;
        }
        continue;
      } else if (use.instruction->opcode() == HloOpcode::kWhile) {
        auto* cond_parameter =
            use.instruction->while_condition()->parameter_instruction(
                use.operand_number);
        if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
          return false;
        }
        auto* body_parameter =
            use.instruction->while_body()->parameter_instruction(
                use.operand_number);
        if (OutputTypeAfterChange(body_parameter, use.operand_index) != BF16) {
          return false;
        }
        continue;
      } else if (use.instruction->opcode() == HloOpcode::kConditional) {
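        // Operand 0 of a conditional is the branch index, so operand n feeds
        // the single parameter of branch computation n - 1, hence the offset
        // below.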
        auto* cond_parameter =
            use.instruction->branch_computation(use.operand_number - 1)
                ->parameter_instruction(0);
        if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
          return false;
        }
        continue;
      }
      if (bfloat16_support_->EffectiveOperandPrecisionIsBF16(
              *use.instruction, use.operand_number)) {
        continue;
      }
      // If the op propagates precision and it outputs a BF16, then it's OK to
      // supply BF16 also as the input. In the backward pass, the users' shapes
      // should have already been processed.
      if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(
              *use.instruction, use.operand_number)) {
        if (use.instruction->opcode() == HloOpcode::kTuple ||
            (use.instruction->opcode() == HloOpcode::kAllReduce &&
             use.instruction->shape().IsTuple())) {
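          // For a tuple (or tuple-shaped all-reduce), operand n of the use
          // appears at output shape index {n, ...}, so prepend the operand
          // number to the operand's shape index to find the matching output
          // subshape.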
          ShapeIndex use_output_index{use.operand_number};
          for (int64_t i : use.operand_index) {
            use_output_index.push_back(i);
          }
          if (OutputTypeAfterChange(use.instruction, use_output_index) ==
              BF16) {
            continue;
          }
        } else if (use.instruction->opcode() == HloOpcode::kGetTupleElement) {
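          // A get-tuple-element maps operand shape index {i, rest...} to
          // output shape index {rest...}, so drop the leading tuple index.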
          ShapeIndex use_output_index;
          for (int64_t i = 1; i < use.operand_index.size(); ++i) {
            use_output_index.push_back(use.operand_index[i]);
          }
          if (OutputTypeAfterChange(use.instruction, use_output_index) ==
              BF16) {
            continue;
          }
        } else {
          if (OutputTypeAfterChange(use.instruction, use.operand_index) ==
              BF16) {
            continue;
          }
        }
      }
      return false;
    }
  }
  return true;
}

bool BFloat16Propagation::ShouldKeepPrecisionUnchanged(
    const HloInstruction* inst) {
  if (inst->opcode() == HloOpcode::kFusion &&
      inst->fusion_kind() == HloInstruction::FusionKind::kCustom) {
    return ShouldKeepPrecisionUnchanged(
        inst->fused_instructions_computation()->root_instruction());
  }
  // Do not change precision for side-effecting instructions, control flow, and
  // bitcast-convert, because this pass might break the interfaces or
  // assumptions for them.
  return inst->opcode() == HloOpcode::kCustomCall ||
         inst->opcode() == HloOpcode::kCall ||
         inst->opcode() == HloOpcode::kBitcastConvert ||
         inst->HasSideEffectNoRecurse();
}

void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
                                                        bool skip_parameters) {
  // We handle any fusion computation, while body/condition or conditional
  // branches after the instruction is handled, because we need to know the
  // output shape of a fusion or while before propagating inside its
  // computations.
  bool postpone_processing_called_computations = false;
  absl::Cleanup cleaner = [this, hlo,
                           &postpone_processing_called_computations] {
    if (!postpone_processing_called_computations) {
      if (hlo->opcode() == HloOpcode::kFusion) {
        DetermineFusionComputationPrecision(hlo);
      } else if (hlo->opcode() == HloOpcode::kWhile) {
        DetermineWhileComputationsPrecision(hlo);
      } else if (hlo->opcode() == HloOpcode::kConditional) {
        DetermineConditionalComputationsPrecision(hlo);
      }
    }
    instructions_visited_in_backward_pass_.insert(hlo);
  };

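  // A while condition or body that is shared by other callers cannot safely
  // have its precision specialized for this call site, so skip processing its
  // called computations here.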
  if (hlo->opcode() == HloOpcode::kWhile &&
      (caller_counts_[hlo->while_condition()] > 1 ||
       caller_counts_[hlo->while_body()] > 1)) {
    postpone_processing_called_computations = true;
    return;
  }

  if (hlo->opcode() == HloOpcode::kConditional &&
      absl::c_any_of(hlo->branch_computations(), [&](const HloComputation* c) {
        return caller_counts_[c] > 1;
      })) {
    postpone_processing_called_computations = true;
    return;
  }

  // Prevent root instructions from having their output modified by recording
  // all F32 output values as needing to stay as F32.
  CHECK(hlo->parent() != nullptr);
  if (hlo == hlo->parent()->root_instruction()) {
    if (!hlo->parent()->IsFusionComputation()) {
      ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& /* subshape */,
                                                   const ShapeIndex& index) {
        if (OutputTypeAfterChange(hlo, index) != F32) {
          return;
        }
        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
          // Since we use HloValues from the dataflow analysis, this can also
          // affect HLO instructions beyond the root, e.g., if the root is a
          // Tuple HLO, then its operands are also affected.
          values_that_must_be_kept_as_f32_.insert(value);
        }
      });
    }
    return;
  }

  if (ShouldKeepPrecisionUnchanged(hlo) ||
      (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
    return;
  }

  if (!ContainsKey(consider_using_bfloat16_, hlo)) {
    return;
  }

  if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
    return;
  }

  ShapeUtil::ForEachSubshape(
      hlo->shape(),
      [hlo, this](const Shape& /* subshape */, const ShapeIndex& index) {
        if (OutputTypeAfterChange(hlo, index) == F32 &&
            AllUsersConsumeBF16(*hlo, index)) {
          AddToOrRemoveFromBF16ChangeSet(hlo, index, BF16);
          VLOG(2) << "HloInstruction output at shape index " << index
                  << " changed to BF16 precision: " << hlo->ToString();
        }
      });
}

bool BFloat16Propagation::InstructionIsCandidateForBF16Output(
    HloInstruction* hlo) {
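  // An instruction that cannot hold mixed precisions, is not a pure
  // pass-through (tuple/get-tuple-element/domain), and does not already
  // output BF16 is only a candidate if every operand propagates its precision
  // to the output and is itself a candidate; otherwise using BF16 here would
  // require extra converts.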
  if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) &&
      hlo->opcode() != HloOpcode::kTuple &&
      hlo->opcode() != HloOpcode::kGetTupleElement &&
      hlo->opcode() != HloOpcode::kDomain &&
      hlo->shape().element_type() != BF16) {
    for (int64_t i = 0; i < hlo->operand_count(); ++i) {
      if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
                                                                         i) ||
          !ContainsKey(consider_using_bfloat16_, hlo->operand(i))) {
        return false;
      }
    }
  }
  return true;
}

void BFloat16Propagation::AdjustCalledComputationParameters(
    HloInstruction* hlo) {
  auto adjust_computation = [this, hlo](
                                HloComputation* computation,
                                absl::Span<HloInstruction* const> operands) {
    // Adjust parameters.
    CHECK_EQ(operands.size(), computation->num_parameters());
    for (int64_t i = 0; i < operands.size(); ++i) {
      auto parameter = computation->parameter_instruction(i);
      ShapeUtil::ForEachSubshape(
          parameter->shape(),
          [this, i, hlo, &operands, parameter](const Shape& /* subshape */,
                                               const ShapeIndex& index) {
            if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
              return;
            }
            PrimitiveType operand_type =
                OutputTypeAfterChange(operands[i], index);
            if (OutputTypeAfterChange(parameter, index) == operand_type) {
              return;
            }
            AddToOrRemoveFromBF16ChangeSet(parameter, index, operand_type);
            VLOG(2) << "Called computation parameter " << parameter->ToString()
                    << " at shape index " << index << " adjusted to "
                    << (operand_type == BF16 ? "BF16" : "F32")
                    << " to match operand in HLO " << hlo->ToString();
          });
    }
  };

  switch (hlo->opcode()) {
    case HloOpcode::kFusion:
      adjust_computation(hlo->fused_instructions_computation(),
                         hlo->operands());
      break;
    case HloOpcode::kWhile:
      adjust_computation(hlo->while_condition(), hlo->operands());
      adjust_computation(hlo->while_body(), hlo->operands());
      break;
    case HloOpcode::kConditional:
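      // Operand 0 is the branch index; operand i + 1 is the argument passed
      // to branch computation i.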
      for (int64_t i = 0; i < hlo->branch_count(); ++i) {
        adjust_computation(hlo->branch_computation(i),
                           {hlo->mutable_operand(i + 1)});
      }
      break;
    default:
      break;
  }
}

void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
  auto adjust_computation = [this, hlo](HloComputation* computation,
                                        HloInstruction* output) {
    // Adjust root.
    HloInstruction* root = computation->root_instruction();
    ShapeUtil::ForEachSubshape(root->shape(), [this, hlo, root, output](
                                                  const Shape& /* subshape */,
                                                  const ShapeIndex& index) {
      if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) {
        return;
      }
      const PrimitiveType output_type = OutputTypeAfterChange(output, index);
      if (OutputTypeAfterChange(root, index) == output_type) {
        return;
      }
      AddToOrRemoveFromBF16ChangeSet(root, index, output_type);
      // It's possible that output_type is F32, but the root instruction's
      // type is BF16; e.g., a fusion node's output was changed to BF16
      // initially but then adjusted back to F32, and the fusion computation
      // is now being adjusted after the fusion node.
      if (output_type == F32) {
        for (const auto* value :
             dataflow_->GetValueSet(root, index).values()) {
          // We rely on the fact that this adjustment works in reverse
          // topological order so that called computations will be processed
          // later. Adding the value to values_that_must_be_kept_as_f32_ will
          // ensure the correctness of the adjustment for HLOs that will be
          // processed later.
          values_that_must_be_kept_as_f32_.insert(value);
        }
      }
      VLOG(2) << "Called computation root " << root->ToString()
              << " at shape index " << index << " adjusted to "
              << (output_type == BF16 ? "BF16" : "F32")
              << " to match output shape of " << hlo->ToString();
    });
  };

  switch (hlo->opcode()) {
    case HloOpcode::kFusion:
      adjust_computation(hlo->fused_instructions_computation(), hlo);
      break;
    case HloOpcode::kWhile:
      adjust_computation(hlo->while_body(), hlo);
      break;
    case HloOpcode::kConditional:
      for (auto* branch : hlo->branch_computations()) {
        adjust_computation(branch, hlo);
      }
      break;
    default:
      break;
  }
}

bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
    HloComputation* computation,
    absl::flat_hash_set<const HloComputation*>* visited_computations) {
  bool parameter_changed = false;
  auto insts = computation->MakeInstructionPostOrder();
  // Do the adjustment on each instruction in the computation in reverse
  // topological order.
  while (true) {
    bool any_change = false;
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      auto hlo = *inst_it;
      auto adjust_hlo_output = [&](const Shape& /* subshape */,
                                   const ShapeIndex& index) {
        auto output_type = OutputTypeAfterChange(hlo, index);
        VLOG(2) << "output_type is " << ((output_type == BF16) ? "BF16" : "F32")
                << " for: " << hlo->ToString() << "\n";
        if (output_type != F32 && output_type != BF16) {
          return;
        }
        PrimitiveType type = BF16;
        for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
          auto value_type = ValueTypeAfterChange(value);
          if (value_type == BF16) {
            continue;
          }
          VLOG(2) << "Adjust to F32 due to aliased dataflow value: "
                  << value->ToString() << "\n";
          CHECK_EQ(value_type, F32);
          type = F32;
          break;
        }
        // In order to find aliases due to in-place operations, use
        // GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here,
        // but this code works with HloModules that aren't ready yet to use
        // HloAliasAnalysis (e.g., their computation graphs may not have been
        // flattened yet).
        for (const auto& operand_and_output_index :
             HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) {
          if (operand_and_output_index.second == index) {
            const HloOperandIndex& operand_index =
                operand_and_output_index.first;
            for (const auto* value :
                 dataflow_
                     ->GetValueSet(hlo->operand(operand_index.operand_number),
                                   operand_index.operand_index)
                     .values()) {
              auto value_type = ValueTypeAfterChange(value);
              if (value_type == BF16) {
                continue;
              }
              VLOG(2) << "Adjust to F32 due to InputOutputPair: "
                      << value->ToString() << "\n";
              CHECK_EQ(value_type, F32);
              type = F32;
              break;
            }
          }
        }

        // It's possible that a user has been changed from BF16 to F32
        // during this final adjustment pass, so we need to check
        // AllUsersConsumeBF16() again.
        if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
          VLOG(2) << "Adjust to F32 because AllUsersConsumeBF16 failed\n";
          type = F32;
        }
        if (type == F32) {
          for (const auto* value :
               dataflow_->GetValueSet(hlo, index).values()) {
            // We rely on the fact that this adjustment works in reverse
            // topological order. Adding the value to
            // values_that_must_be_kept_as_f32_ will ensure the correctness
            // of the adjustment for HLOs that will be processed later.
            values_that_must_be_kept_as_f32_.insert(value);
          }
        }
        if (type != output_type) {
          any_change = true;
          AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
          VLOG(2) << "HloInstruction output at shape index " << index
                  << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
                  << hlo->ToString();
          if (hlo->opcode() == HloOpcode::kParameter) {
            parameter_changed = true;
          }
        }
      };
      ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
      AdjustCalledComputationRoot(hlo);
      if (hlo->opcode() == HloOpcode::kWhile) {
        // We need to run on the while body and condition repeatedly until a
        // fixed point is reached, i.e., the parameters do not change any more.
        // We may need more than one iteration because the while input and
        // output alias each other, so changing one input parameter requires
        // changing the corresponding output element and thus may transitively
        // require changing another input parameter. A fixed point will be
        // reached because the parameters can only be changed from BF16 to F32,
        // not the other way around.
        absl::flat_hash_set<const HloComputation*> visited_in_while;
        while (ResolveInconsistencyOfAliasingBuffersHelper(
                   hlo->while_condition(), &visited_in_while) ||
               ResolveInconsistencyOfAliasingBuffersHelper(
                   hlo->while_body(), &visited_in_while)) {
          visited_in_while.clear();
          ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
          AdjustCalledComputationRoot(hlo);
        }
        visited_computations->insert(visited_in_while.begin(),
                                     visited_in_while.end());
      } else if (hlo->opcode() == HloOpcode::kFusion) {
        ResolveInconsistencyOfAliasingBuffersHelper(
            hlo->fused_instructions_computation(), visited_computations);
      } else if (hlo->opcode() == HloOpcode::kConditional) {
        for (auto* branch : hlo->branch_computations()) {
          ResolveInconsistencyOfAliasingBuffersHelper(branch,
                                                      visited_computations);
        }
      }
    }
    if (!any_change) {
      break;
    }
  }
  // Now adjust parameters of called computations.
  for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
    AdjustCalledComputationParameters(*inst_it);
  }
  return parameter_changed;
}

void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  const auto& computations_topological_order =
      module->MakeComputationPostOrder(execution_threads);
  absl::flat_hash_set<const HloComputation*> resolved;
  for (auto comp_it = computations_topological_order.rbegin();
       comp_it != computations_topological_order.rend(); ++comp_it) {
    if (ContainsKey(resolved, *comp_it)) {
      continue;
    }
    ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved);
  }
}

Status BFloat16Propagation::ResolveInconsistentFusions(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  // We could have changed a fusion computation's root shape to have a
  // different precision than the fusion node's output, if the fusion root
  // does not define a buffer (e.g., a tuple). Now we add conversions after
  // such fusion roots to make them match the fusion output. If the fusion
  // output is a (possibly nested) tuple, we first create get-tuple-elements,
  // then convert the unmatching leaf nodes, and finally create a new tuple as
  // the fusion computation's root. If tuples and get-tuple-elements are
  // created, we will run tuple simplifier and dead code elimination at the
  // end (dead code is not allowed in a fusion computation). E.g.,
  //
  //   (1)            (2)            (3)
  //   a  b           a  b           a  b
  //   |\ |           |\ |           |\ |
  //   \ add    ->    |add    ->     | add
  //    \ |           \ |       convert |
  //    tuple         tuple         \   |
  //                   / \          tuple
  //                 gte gte
  //                  |   |
  //             convert  |
  //                  \  /
  //                 tuple
  //   (1) a is F32 but tuple is BF16
  //   (2) after adding conversion
  //   (3) after tuple simplifier and DCE.
  for (auto computation :
       module->MakeComputationPostOrder(execution_threads)) {
    auto insts = computation->MakeInstructionPostOrder();
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      auto hlo = *inst_it;
      if (hlo->opcode() != HloOpcode::kFusion) {
        continue;
      }
      auto fusion_computation = hlo->fused_instructions_computation();
      auto fusion_root = fusion_computation->root_instruction();
      if (ShapeUtil::Compatible(fusion_root->shape(), hlo->shape())) {
        continue;
      }
      ShapeTree<HloInstruction*> converted_outputs(hlo->shape());
      // Deep copy the fusion root, and convert a leaf node only if its shape
      // does not match the fusion output.
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          fusion_computation->DeepCopyInstructionWithCustomCopier(
              fusion_root,
              [hlo](HloInstruction* leaf, const ShapeIndex& leaf_index,
                    HloComputation* comp) {
                const Shape& hlo_subshape =
                    ShapeUtil::GetSubshape(hlo->shape(), leaf_index);
                if (ShapeUtil::Compatible(leaf->shape(), hlo_subshape)) {
                  return leaf;
                }
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(hlo_subshape, leaf));
              }));
      fusion_computation->set_root_instruction(copy);
    }
  }
  return OkStatus();
}

Status BFloat16Propagation::ResolveConvertedConstants(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  // We may have converted some constants from F32 to BF16, so adjust the
  // constant literals in such cases. We do this here instead of when the
  // constant node's precision is changed because 1) the HloInstruction
  // interface does not allow resetting the literal so we have to create a new
  // kConstant instruction to replace the old one, which invalidates dataflow
  // analysis, and 2) it's possible that a kConstant's output gets changed to
  // BF16 at the beginning but later on adjusted back to F32, so converting
  // literals here can avoid repeated conversions.
  //
  // TODO(b/73833576): Consider resetting literal in HloInstruction.
  for (auto computation :
       module->MakeComputationPostOrder(execution_threads)) {
    for (auto hlo : computation->MakeInstructionPostOrder()) {
      if (hlo->opcode() != HloOpcode::kConstant) {
        continue;
      }
      if (!Shape::Equal().MinorToMajorOnlyInLayout()(hlo->literal().shape(),
                                                     hlo->shape())) {
        TF_ASSIGN_OR_RETURN(auto converted_literal,
                            hlo->literal().ConvertToShape(hlo->shape()));
        auto new_constant = computation->AddInstruction(
            HloInstruction::CreateConstant(std::move(converted_literal)));
        UpdateLayout(new_constant->mutable_shape());
        TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
      }
    }
  }
  return OkStatus();
}

Status BFloat16Propagation::SkipNoopConversions(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
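  // Removes kConvert instructions whose operand already has exactly the same
  // shape; these include the no-op while-input copies inserted at the start
  // of Run() that did not end up enabling any BF16 change.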
  for (auto computation : module->computations(execution_threads)) {
    for (auto hlo : computation->MakeInstructionPostOrder()) {
      if (hlo->opcode() != HloOpcode::kConvert) {
        continue;
      }
      auto source = hlo->mutable_operand(0);
      if (!ShapeUtil::Equal(source->shape(), hlo->shape())) {
        continue;
      }
      const bool is_root = hlo == computation->root_instruction();
      TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(source));
      if (is_root) {
        computation->set_root_instruction(source);
      }
    }
  }
  return OkStatus();
}

// The algorithm first does a forward pass (parameters to root) to determine a
// set of instructions to consider using bfloat16, then does a backward pass to
// determine the precisions of those instructions according to the need of
// their users. During the backward pass, the potential changes are stored in
// changes_to_bf16_, which are subject to further adjustments and then applied
// to the HLOs.
StatusOr<bool> BFloat16Propagation::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  consider_using_bfloat16_.clear();
  instructions_visited_in_backward_pass_.clear();
  computations_visited_in_backward_pass_.clear();
  values_that_must_be_kept_as_f32_.clear();
  caller_counts_.clear();
  changes_to_bf16_.clear();
  changed_ = false;

  auto computations_topological_order =
      module->MakeComputationPostOrder(execution_threads);

  // Before running the propagation pass, we insert copies (kConvert to the
  // same type) of F32 inputs to while loops. This prevents other uses of the
  // same input from aliasing the while loop input/output, so that there's
  // greater chance to use BF16 inside the loop. If some of these added copies
  // do not help, they will remain F32 after BF16 propagation and will be
  // removed since they are no-ops.
  for (auto computation : computations_topological_order) {
    for (auto inst : computation->MakeInstructionPostOrder()) {
      if (inst->opcode() != HloOpcode::kWhile) {
        continue;
      }

      auto operand = inst->mutable_operand(0);
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          computation->DeepCopyInstructionWithCustomCopier(
              operand, [](HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* comp) {
                if (leaf->shape().element_type() != F32) {
                  return leaf;
                }
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(leaf->shape(), leaf));
              }));
      TF_RETURN_IF_ERROR(operand->ReplaceUseWith(inst, copy));
    }
  }

  TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));

  // The first step is a forward pass (parameters to root), where we determine
  // the potential candidate instructions to use bfloat16 in the outputs that
  // are not likely to cause overhead from extra explicit conversions. This is
  // done forwardly because we determine whether an HLO is a candidate
  // partially based on whether its operands are candidates.
  for (auto computation : computations_topological_order) {
    for (auto inst : computation->MakeInstructionPostOrder()) {
      if (InstructionIsCandidateForBF16Output(inst)) {
        consider_using_bfloat16_.insert(inst);
      }
    }
  }

  // The second step is a backward pass (root to parameters), where we modify
  // the precisions of the instructions identified in the first step when
  // feasible. This is done backwardly because we determine the precision of an
  // HLO's output based on how it is later used.
  //
  // The precision of an instruction is determined by its users, so we do the
  // propagation in reverse topological order.
  for (auto comp_it = computations_topological_order.rbegin();
       comp_it != computations_topological_order.rend(); ++comp_it) {
    if (ContainsKey(computations_visited_in_backward_pass_, *comp_it)) {
      continue;
    }
    auto insts = (*comp_it)->MakeInstructionPostOrder();
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      DetermineInstructionPrecision(*inst_it,
                                    /*skip_parameters=*/true);
    }
    computations_visited_in_backward_pass_.insert(*comp_it);
  }

  // It's possible that an instruction does not define a buffer, but the
  // defining instruction's shape has changed. So we need to adjust the output
  // shapes of instructions according to the HLO values they refer to.
  ResolveInconsistencyOfAliasingBuffers(module, execution_threads);

  // Apply the changes in changes_to_bf16_.
  for (auto& change : changes_to_bf16_) {
    auto inst = change.first;
    // It is possible that we marked inst to change precision even if it is an
    // unsupported change, when inst is the root of a fusion computation and it
    // has to match the fusion node's output precision. We do a convert instead
    // of an in-place change for such cases.
    if (ShouldKeepPrecisionUnchanged(inst)) {
      auto users = inst->users();
      bool is_root = inst == inst->parent()->root_instruction();
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          inst->parent()->DeepCopyInstructionWithCustomCopier(
              inst, [&](HloInstruction* leaf, const ShapeIndex& leaf_index,
                        HloComputation* comp) {
                if (!ContainsKey(change.second,
                                 ShapeUtil::GetMutableSubshape(
                                     inst->mutable_shape(), leaf_index))) {
                  return leaf;
                }
                auto converted_shape =
                    ShapeUtil::ChangeElementType(leaf->shape(), BF16);
                UpdateLayout(&converted_shape);
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(converted_shape, leaf));
              }));
      for (auto user : users) {
        TF_RETURN_IF_ERROR(inst->ReplaceUseWithDifferentShape(user, copy));
      }
      if (is_root) {
        inst->parent()->set_root_instruction(copy,
                                             /*accept_different_shape=*/true);
      }
      continue;
    }
    for (const auto& entry : change.second) {
      auto subshape = entry.first;
      CHECK_EQ(subshape->element_type(), F32);
      subshape->set_element_type(BF16);
      UpdateLayout(subshape);
      changed_ = true;
    }
  }

  // Removes redundant HLOs added by this pass, either when inserting
  // de-aliasing copies to while loop inputs, or later when converting output
  // types.
  auto clean_up = [this, module, &execution_threads]() {
    TF_RETURN_IF_ERROR(SkipNoopConversions(module, execution_threads));
    TupleSimplifier tuple_simplifier;
    TF_RETURN_IF_ERROR(
        tuple_simplifier.Run(module, execution_threads).status());
    HloDCE dce;
    TF_RETURN_IF_ERROR(dce.Run(module, execution_threads).status());
    return OkStatus();
  };

  if (!changed_) {
    TF_RETURN_IF_ERROR(clean_up());
    return false;
  }

  TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module, execution_threads));
  TF_RETURN_IF_ERROR(ResolveConvertedConstants(module, execution_threads));

  TF_RETURN_IF_ERROR(clean_up());
  return true;
}

PrimitiveType BFloat16Propagation::OutputTypeAfterChange(
    HloInstruction* hlo, const ShapeIndex& index) const {
  Shape* subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index);
  const PrimitiveType type_on_hlo = subshape->element_type();
  if (type_on_hlo != F32) {
    return type_on_hlo;
  }
  auto it = changes_to_bf16_.find(hlo);
  if (it == changes_to_bf16_.end()) {
    return type_on_hlo;
  }
  return ContainsKey(it->second, subshape) ? BF16 : F32;
}

PrimitiveType BFloat16Propagation::ValueTypeAfterChange(
    const HloValue* value) const {
  auto hlo = value->defining_instruction();
  const auto& position = value->defining_position();
  return OutputTypeAfterChange(hlo, position.index);
}

void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet(
    HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) {
  if (target_type == BF16) {
    auto& entry = changes_to_bf16_[hlo];
    entry.emplace(ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index),
                  index);
  } else {
    CHECK_EQ(target_type, F32);
    auto it = changes_to_bf16_.find(hlo);
    if (it == changes_to_bf16_.end()) {
      return;
    }
    it->second.erase(
        ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index));
  }
}

}  // namespace xla