xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/bfloat16_propagation.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/cleanup/cleanup.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "tensorflow/compiler/xla/literal.h"
22 #include "tensorflow/compiler/xla/map_util.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_dce.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_module.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
29 #include "tensorflow/compiler/xla/shape_tree.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/core/platform/logging.h"
32 
33 namespace xla {
34 
// Constructs the pass. `bfloat16_support` describes which HLO instructions
// the target backend can execute with BF16 inputs/outputs/mixed precision
// (pointer is stored, not owned; it must outlive this pass object).
BFloat16Propagation(const BFloat16Support * bfloat16_support)35 BFloat16Propagation::BFloat16Propagation(
36     const BFloat16Support* bfloat16_support)
37     : bfloat16_support_(bfloat16_support) {}
38 
// Pushes the BF16 decisions already made for `fusion`'s output shape into its
// fused computation: the fused root's F32 leaves are switched to BF16 where
// the fusion output switched, the interior is processed by the backward pass,
// and interior-only changes are reverted afterwards.
DetermineFusionComputationPrecision(HloInstruction * fusion)39 void BFloat16Propagation::DetermineFusionComputationPrecision(
40     HloInstruction* fusion) {
41   CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
42   if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) {
43     return;
44   }
45 
46   // We are depending on the fusion node itself having already been analyzed
47   // for whether it can output BF16 and this has been adjusted in the output
48   // shape, and now we're looking to update the interior of the fusion node to
49   // match the new output shape, as well as recursively process the whole fusion
50   // node even if the output shape was not modified.
51   auto root = fusion->fused_instructions_computation()->root_instruction();
52 
53   // Adjust root's element types according to the fusion's output shape.
54   ShapeUtil::ForEachSubshape(
55       root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
56         if (subshape.element_type() != F32) {
57           return;
58         }
59         if (OutputTypeAfterChange(fusion, index) == BF16) {
60           AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
61           VLOG(2) << "Fused root " << root->ToString() << " at shape index "
62                   << index << " changed to BF16 precision for fusion "
63                   << fusion->ToString();
64         }
65       });
66 
67   // Propagate BF16 in the fusion computation.
  // Iterating the post order in reverse visits users before their operands,
  // which is what the backward precision pass requires.
68   auto insts =
69       fusion->fused_instructions_computation()->MakeInstructionPostOrder();
70   for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
71     DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
72   }
73   computations_visited_in_backward_pass_.insert(
74       fusion->fused_instructions_computation());
75 
76   RevertIfFusionInternalBF16Changes(fusion);
77 }
78 
// Reverts tentative BF16 changes on instructions inside `fusion`'s
// computation when the change is interior-only: the instruction neither
// aliases any root buffer that switched to BF16 nor has an operand whose
// BF16 change survived. Parameters are skipped; nested fusions whose fused
// parameters get reverted are processed recursively.
RevertIfFusionInternalBF16Changes(HloInstruction * fusion)79 void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
80     HloInstruction* fusion) {
  // True iff `inst` still has at least one recorded BF16 change.
81   auto has_changes = [this](HloInstruction* inst) {
82     auto it = changes_to_bf16_.find(inst);
83     return it != changes_to_bf16_.end() && !it->second.empty();
84   };
85 
86   auto root = fusion->fused_instructions_computation()->root_instruction();
87   absl::flat_hash_set<const HloValue*> changed_root_buffers;
88 
  // Collect all dataflow values reachable at root indices that changed to
  // BF16; instructions aliasing these must keep their changes.
89   auto root_changes_it = changes_to_bf16_.find(root);
90   if (root_changes_it != changes_to_bf16_.end()) {
91     for (const auto& entry : root_changes_it->second) {
92       for (const HloValue* value :
93            dataflow_->GetValueSet(root, entry.second).values()) {
94         changed_root_buffers.insert(value);
95       }
96     }
97   }
98 
99   auto aliases_changed_root_buffer = [this, &changed_root_buffers](
100                                          const HloInstruction* inst) {
101     bool aliasing = false;
102     ShapeUtil::ForEachSubshape(inst->shape(), [&](const Shape& subshape,
103                                                   const ShapeIndex& index) {
104       if (aliasing) {
105         // Skip if aliasing is already found.
106         return;
107       }
108       // Only F32 buffers are considered for changing to BF16 in this
109       // pass.
110       if (subshape.element_type() != F32) {
111         return;
112       }
113 
114       aliasing = absl::c_any_of(dataflow_->GetValueSet(inst, index).values(),
115                                 IsValueIn(changed_root_buffers));
116     });
117     return aliasing;
118   };
119 
120   for (auto inst :
121        fusion->fused_instructions_computation()->MakeInstructionPostOrder()) {
122     if (inst->opcode() == HloOpcode::kParameter) {
123       continue;
124     }
125     if (aliases_changed_root_buffer(inst)) {
126       continue;
127     }
128     if (inst->opcode() == HloOpcode::kFusion) {
      // A nested fusion's fused parameters must track its operands: if an
      // operand's change was reverted, revert the matching parameter too and
      // recurse so the nested computation stays consistent.
129       bool parameter_reverted = false;
130       for (int64_t i = 0; i < inst->operand_count(); ++i) {
131         if (has_changes(inst->mutable_operand(i))) {
132           // Changes on the operand have not been reverted.
133           continue;
134         }
135         auto* fused_parameter = inst->fused_parameter(i);
136         if (has_changes(fused_parameter)) {
137           changes_to_bf16_.erase(fused_parameter);
138           parameter_reverted = true;
139         }
140       }
141       if (parameter_reverted) {
142         RevertIfFusionInternalBF16Changes(inst);
143       }
144     }
145     if (!has_changes(inst)) {
146       continue;
147     }
148     bool revert_changes = true;
149     for (auto operand : inst->operands()) {
150       if (has_changes(operand)) {
151         revert_changes = false;
152         break;
153       }
154     }
155     if (revert_changes) {
156       changes_to_bf16_.erase(inst);
157     }
158   }
159 }
160 
// Pushes the BF16 decisions already made for `while_hlo`'s output into its
// body and condition computations: the body root's F32 leaves are switched
// where the while output switched, then both computations are processed by
// the backward pass.
DetermineWhileComputationsPrecision(HloInstruction * while_hlo)161 void BFloat16Propagation::DetermineWhileComputationsPrecision(
162     HloInstruction* while_hlo) {
163   CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);
164 
165   // We are depending on the while node itself having already been analyzed for
166   // whether it can output BF16 and this has been adjusted in the output shape,
167   // and now we're looking to update the body and condition computations to
168   // match the new output shape, as well as recursively process the whole while
169   // node even if the output shape was not modified.
170   HloComputation* body = while_hlo->while_body();
171   auto body_root = body->root_instruction();
172   HloComputation* condition = while_hlo->while_condition();
173 
174   ShapeUtil::ForEachSubshape(
175       body_root->shape(), [this, while_hlo, body_root](
176                               const Shape& subshape, const ShapeIndex& index) {
177         if (subshape.element_type() != F32) {
178           return;
179         }
180         if (OutputTypeAfterChange(while_hlo, index) == BF16) {
181           AddToOrRemoveFromBF16ChangeSet(body_root, index, BF16);
182           VLOG(2) << "While body root " << body_root->ToString()
183                   << " at shape index " << index
184                   << " changed to BF16 precision for while "
185                   << while_hlo->ToString();
186         }
187       });
188 
  // Reverse post-order visits users before operands (backward pass).
189   auto body_insts = body->MakeInstructionPostOrder();
190   for (auto inst_it = body_insts.rbegin(); inst_it != body_insts.rend();
191        ++inst_it) {
192     DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
193   }
194   computations_visited_in_backward_pass_.insert(body);
195 
196   auto condition_insts = condition->MakeInstructionPostOrder();
197   for (auto inst_it = condition_insts.rbegin();
198        inst_it != condition_insts.rend(); ++inst_it) {
199     DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
200   }
201   computations_visited_in_backward_pass_.insert(condition);
202 }
203 
// Pushes the BF16 decisions already made for `cond`'s output into each branch
// computation: every branch root's F32 leaves are switched where the
// conditional's output switched, then each branch is processed by the
// backward pass.
DetermineConditionalComputationsPrecision(HloInstruction * cond)204 void BFloat16Propagation::DetermineConditionalComputationsPrecision(
205     HloInstruction* cond) {
206   CHECK_EQ(cond->opcode(), HloOpcode::kConditional);
207   for (int64_t i = 0; i < cond->branch_count(); ++i) {
208     auto branch = cond->branch_computation(i);
209     auto root = branch->root_instruction();
210     ShapeUtil::ForEachSubshape(
211         root->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
212           if (subshape.element_type() != F32) {
213             return;
214           }
215           if (OutputTypeAfterChange(cond, index) == BF16) {
216             AddToOrRemoveFromBF16ChangeSet(root, index, BF16);
217             VLOG(2) << "Conditional branch " << i << " root "
218                     << root->ToString() << " at shape index " << index
219                     << " changed to BF16 precision for conditional "
220                     << cond->ToString();
221           }
222         });
223     auto insts = branch->MakeInstructionPostOrder();
224     for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
225       DetermineInstructionPrecision(*inst_it, /*skip_parameters=*/false);
226     }
227     computations_visited_in_backward_pass_.insert(branch);
228   }
229 }
230 
// Returns true iff every already-visited use of `hlo`'s value at shape
// `index` can consume BF16 there — i.e., it is safe (so far) to change that
// output leaf to BF16. Values pinned in values_that_must_be_kept_as_f32_ and
// uses by side-effecting instructions force a false result.
AllUsersConsumeBF16(const HloInstruction & hlo,const ShapeIndex & index) const231 bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
232                                               const ShapeIndex& index) const {
233   // If the subshape isn't floating point then none of the users will be BF16.
234   const Shape& subshape = ShapeUtil::GetSubshape(hlo.shape(), index);
235   if (subshape.element_type() != BF16 && subshape.element_type() != F32) {
236     return false;
237   }
238 
239   auto& value_set = dataflow_->GetValueSet(&hlo, index);
240   for (const HloValue* value : value_set.values()) {
241     if (ContainsKey(values_that_must_be_kept_as_f32_, value)) {
242       return false;
243     }
244     // We use the original type for the value because we are going to examine
245     // the uses of it, instead of the value itself. If ValueTypeAfterChange()
246     // were used, it would cause problems when there are aliasing buffers, i.e.,
247     // ResolveInconsistencyOfAliasingBuffers() would fail to revert the
248     // tentative change to BF16 even if the uses require F32.
249     if (value->shape().element_type() == BF16) {
250       continue;
251     }
252     for (const HloUse& use : value->GetUses()) {
253       if (!ContainsKey(instructions_visited_in_backward_pass_,
254                        use.instruction)) {
255         // We don't know yet whether use.instruction will consume BF16 since it
256         // hasn't been visited. Although we visit instructions in reverse
257         // topological order, this is still possible because there may be
258         // unvisited instruction that alias the same buffer. In this case, we
259         // aggressively skip this use, and if this causes inconsistency (e.g.,
260         // one use is in BF16 but another use is in F32), it will be resolved at
261         // the end of the BFloat16Propagation pass.
262         continue;
263       }
264       if (use.instruction->HasSideEffectNoRecurse()) {
265         // Keep side-effecting instruction's operands unchanged.
266         return false;
267       }
268       // Any visited user that can accept BF16 has already been updated if
269       // necessary, e.g., the output has been changed to BF16 if it propagates
270       // precision, or a called computation's parameters have been changed to
271       // BF16 for fusions or whiles.
272       if (use.instruction->opcode() == HloOpcode::kFusion) {
273         auto* fused_parameter =
274             use.instruction->fused_parameter(use.operand_number);
275         if (OutputTypeAfterChange(fused_parameter, use.operand_index) != BF16) {
276           return false;
277         }
278         continue;
279       } else if (use.instruction->opcode() == HloOpcode::kWhile) {
        // A while use must be BF16 in both called computations' parameters.
280         auto* cond_parameter =
281             use.instruction->while_condition()->parameter_instruction(
282                 use.operand_number);
283         if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
284           return false;
285         }
286         auto* body_parameter =
287             use.instruction->while_body()->parameter_instruction(
288                 use.operand_number);
289         if (OutputTypeAfterChange(body_parameter, use.operand_index) != BF16) {
290           return false;
291         }
292         continue;
293       } else if (use.instruction->opcode() == HloOpcode::kConditional) {
        // Operand 0 of kConditional selects the branch; operand i+1 feeds
        // branch i's single parameter, hence the -1 offset below.
294         auto* cond_parameter =
295             use.instruction->branch_computation(use.operand_number - 1)
296                 ->parameter_instruction(0);
297         if (OutputTypeAfterChange(cond_parameter, use.operand_index) != BF16) {
298           return false;
299         }
300         continue;
301       }
302       if (bfloat16_support_->EffectiveOperandPrecisionIsBF16(
303               *use.instruction, use.operand_number)) {
304         continue;
305       }
306       // If the op propagates precision and it outputs a BF16, then it's OK to
307       // supply BF16 also as the input. In the backward pass, the users shapes
308       // should have already been processed.
309       if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(
310               *use.instruction, use.operand_number)) {
        // Map the operand position to the corresponding output shape index:
        // tuple-forming ops prepend the operand number, get-tuple-element
        // drops the leading index, everything else maps one-to-one.
311         if (use.instruction->opcode() == HloOpcode::kTuple ||
312             (use.instruction->opcode() == HloOpcode::kAllReduce &&
313              use.instruction->shape().IsTuple())) {
314           ShapeIndex use_output_index{use.operand_number};
315           for (int64_t i : use.operand_index) {
316             use_output_index.push_back(i);
317           }
318           if (OutputTypeAfterChange(use.instruction, use_output_index) ==
319               BF16) {
320             continue;
321           }
322         } else if (use.instruction->opcode() == HloOpcode::kGetTupleElement) {
323           ShapeIndex use_output_index;
324           for (int64_t i = 1; i < use.operand_index.size(); ++i) {
325             use_output_index.push_back(use.operand_index[i]);
326           }
327           if (OutputTypeAfterChange(use.instruction, use_output_index) ==
328               BF16) {
329             continue;
330           }
331         } else {
332           if (OutputTypeAfterChange(use.instruction, use.operand_index) ==
333               BF16) {
334             continue;
335           }
336         }
337       }
338       return false;
339     }
340   }
341   return true;
342 }
343 
ShouldKeepPrecisionUnchanged(const HloInstruction * inst)344 bool BFloat16Propagation::ShouldKeepPrecisionUnchanged(
345     const HloInstruction* inst) {
346   if (inst->opcode() == HloOpcode::kFusion &&
347       inst->fusion_kind() == HloInstruction::FusionKind::kCustom) {
348     return ShouldKeepPrecisionUnchanged(
349         inst->fused_instructions_computation()->root_instruction());
350   }
351   // Do not change precision for side-effecting instructions, control flow, and
352   // bitcast-convert, because this pass might break the interfaces or
353   // assumptions for them.
354   return inst->opcode() == HloOpcode::kCustomCall ||
355          inst->opcode() == HloOpcode::kCall ||
356          inst->opcode() == HloOpcode::kBitcastConvert ||
357          inst->HasSideEffectNoRecurse();
358 }
359 
// Backward-pass visit of one instruction: if every user of an F32 output
// leaf consumes BF16, the leaf is tentatively changed to BF16. On exit (via
// absl::Cleanup) the instruction's called computations (fusion/while/
// conditional) are processed — unless they are called from multiple places,
// in which case processing is skipped here. `skip_parameters` suppresses
// changes to kParameter instructions.
DetermineInstructionPrecision(HloInstruction * hlo,bool skip_parameters)360 void BFloat16Propagation::DetermineInstructionPrecision(HloInstruction* hlo,
361                                                         bool skip_parameters) {
362   // We handle any fusion computation, while body/condition or conditional
363   // branches after the instruction is handled, because we need to know the
364   // output shape of a fusion or while before propagating inside its
365   // computations.
366   bool postpone_processing_called_computations = false;
367   absl::Cleanup cleaner = [this, hlo,
368                            &postpone_processing_called_computations] {
369     if (!postpone_processing_called_computations) {
370       if (hlo->opcode() == HloOpcode::kFusion) {
371         DetermineFusionComputationPrecision(hlo);
372       } else if (hlo->opcode() == HloOpcode::kWhile) {
373         DetermineWhileComputationsPrecision(hlo);
374       } else if (hlo->opcode() == HloOpcode::kConditional) {
375         DetermineConditionalComputationsPrecision(hlo);
376       }
377     }
378     instructions_visited_in_backward_pass_.insert(hlo);
379   };
380 
  // Computations with multiple callers cannot be specialized to this caller.
381   if (hlo->opcode() == HloOpcode::kWhile &&
382       (caller_counts_[hlo->while_condition()] > 1 ||
383        caller_counts_[hlo->while_body()] > 1)) {
384     postpone_processing_called_computations = true;
385     return;
386   }
387 
388   if (hlo->opcode() == HloOpcode::kConditional &&
389       absl::c_any_of(hlo->branch_computations(), [&](const HloComputation* c) {
390         return caller_counts_[c] > 1;
391       })) {
392     postpone_processing_called_computations = true;
393     return;
394   }
395 
396   // Prevent root instructions from having their output modified by recording
397   // all F32 output values as needing to stay as F32.
398   CHECK(hlo->parent() != nullptr);
399   if (hlo == hlo->parent()->root_instruction()) {
400     if (!hlo->parent()->IsFusionComputation()) {
401       ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& /* subshape */,
402                                                    const ShapeIndex& index) {
403         if (OutputTypeAfterChange(hlo, index) != F32) {
404           return;
405         }
406         for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
407           // Since we use HloValues from the dataflow analysis, this can also
408           // affect HLO instructions beyond the root, e.g., if the root is a
409           // Tuple HLO, then its operands are also affected.
410           values_that_must_be_kept_as_f32_.insert(value);
411         }
412       });
413     }
414     return;
415   }
416 
417   if (ShouldKeepPrecisionUnchanged(hlo) ||
418       (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
419     return;
420   }
421 
422   if (!ContainsKey(consider_using_bfloat16_, hlo)) {
423     return;
424   }
425 
426   if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
427     return;
428   }
429 
430   ShapeUtil::ForEachSubshape(
431       hlo->shape(),
432       [hlo, this](const Shape& /* subshape */, const ShapeIndex& index) {
433         if (OutputTypeAfterChange(hlo, index) == F32 &&
434             AllUsersConsumeBF16(*hlo, index)) {
435           AddToOrRemoveFromBF16ChangeSet(hlo, index, BF16);
436           VLOG(2) << "HloInstruction output at shape index " << index
437                   << " changed to BF16 precision: " << hlo->ToString();
438         }
439       });
440 }
441 
InstructionIsCandidateForBF16Output(HloInstruction * hlo)442 bool BFloat16Propagation::InstructionIsCandidateForBF16Output(
443     HloInstruction* hlo) {
444   if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) &&
445       hlo->opcode() != HloOpcode::kTuple &&
446       hlo->opcode() != HloOpcode::kGetTupleElement &&
447       hlo->opcode() != HloOpcode::kDomain &&
448       hlo->shape().element_type() != BF16) {
449     for (int64_t i = 0; i < hlo->operand_count(); ++i) {
450       if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
451                                                                          i) ||
452           !ContainsKey(consider_using_bfloat16_, hlo->operand(i))) {
453         return false;
454       }
455     }
456   }
457   return true;
458 }
459 
// Makes each called computation's parameter leaf types match the (post-
// change) types of the corresponding operands of `hlo`, for fusion, while
// (condition and body), and conditional branches. No-op for other opcodes.
AdjustCalledComputationParameters(HloInstruction * hlo)460 void BFloat16Propagation::AdjustCalledComputationParameters(
461     HloInstruction* hlo) {
462   auto adjust_computation = [this, hlo](
463                                 HloComputation* computation,
464                                 absl::Span<HloInstruction* const> operands) {
465     // Adjust parameters.
466     CHECK_EQ(operands.size(), computation->num_parameters());
467     for (int64_t i = 0; i < operands.size(); ++i) {
468       auto parameter = computation->parameter_instruction(i);
469       ShapeUtil::ForEachSubshape(
470           parameter->shape(),
471           [this, i, hlo, &operands, parameter](const Shape& /* subshape */,
472                                                const ShapeIndex& index) {
473             if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
474               return;
475             }
476             PrimitiveType operand_type =
477                 OutputTypeAfterChange(operands[i], index);
478             if (OutputTypeAfterChange(parameter, index) == operand_type) {
479               return;
480             }
481             AddToOrRemoveFromBF16ChangeSet(parameter, index, operand_type);
482             VLOG(2) << "Called computation parameter " << parameter->ToString()
483                     << " at shape index " << index << " adjusted to "
484                     << (operand_type == BF16 ? "BF16" : "F32")
485                     << " to match operand in HLO " << hlo->ToString();
486           });
487     }
488   };
489 
490   switch (hlo->opcode()) {
491     case HloOpcode::kFusion:
492       adjust_computation(hlo->fused_instructions_computation(),
493                          hlo->operands());
494       break;
495     case HloOpcode::kWhile:
496       adjust_computation(hlo->while_condition(), hlo->operands());
497       adjust_computation(hlo->while_body(), hlo->operands());
498       break;
499     case HloOpcode::kConditional:
      // Operand 0 is the branch selector; operand i+1 feeds branch i.
500       for (int64_t i = 0; i < hlo->branch_count(); ++i) {
501         adjust_computation(hlo->branch_computation(i),
502                            {hlo->mutable_operand(i + 1)});
503       }
504       break;
505     default:
506       break;
507   }
508 }
509 
// Makes each called computation's root leaf types match `hlo`'s (post-change)
// output types, for fusion, while bodies, and conditional branches. Leaves
// adjusted back to F32 have their aliased dataflow values pinned to F32 so
// that later (reverse-topological) processing stays consistent.
AdjustCalledComputationRoot(HloInstruction * hlo)510 void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
511   auto adjust_computation = [this, hlo](HloComputation* computation,
512                                         HloInstruction* output) {
513     // Adjust root.
514     HloInstruction* root = computation->root_instruction();
515     ShapeUtil::ForEachSubshape(root->shape(), [this, hlo, root, output](
516                                                   const Shape& /* subshape */,
517                                                   const ShapeIndex& index) {
518       if (!ShapeUtil::IsLeafIndex(hlo->shape(), index)) {
519         return;
520       }
521       const PrimitiveType output_type = OutputTypeAfterChange(output, index);
522       if (OutputTypeAfterChange(root, index) == output_type) {
523         return;
524       }
525       AddToOrRemoveFromBF16ChangeSet(root, index, output_type);
526       // It's possible that output_type is F32, but the root instruction's
527       // type is BF16; e.g., a fusion node's output was changed to BF16
528       // initially but then adjusted back to F32, and the fusion computation
529       // is now being adjusted after the fusion node.
530       if (output_type == F32) {
531         for (const auto* value : dataflow_->GetValueSet(root, index).values()) {
532           // We rely on the fact that this adjustment works in reverse
533           // topological order so that called computation will be
534           // processed later. Adding the value to
535           // values_that_must_be_kept_as_f32_ will ensure the
536           // correctness of the adjustment for HLOs that will be
537           // processed later.
538           values_that_must_be_kept_as_f32_.insert(value);
539         }
540       }
541       VLOG(2) << "Called computation root " << root->ToString()
542               << " at shape index " << index << " adjusted to "
543               << (output_type == BF16 ? "BF16" : "F32")
544               << " to match output shape of " << hlo->ToString();
545     });
546   };
547 
548   switch (hlo->opcode()) {
549     case HloOpcode::kFusion:
550       adjust_computation(hlo->fused_instructions_computation(), hlo);
551       break;
552     case HloOpcode::kWhile:
553       adjust_computation(hlo->while_body(), hlo);
554       break;
555     case HloOpcode::kConditional:
556       for (auto* branch : hlo->branch_computations()) {
557         adjust_computation(branch, hlo);
558       }
559       break;
560     default:
561       break;
562   }
563 }
564 
// Fixed-point adjustment for one computation: any output leaf whose aliased
// dataflow values disagree (some F32, some BF16) — or whose users can no
// longer all consume BF16 — is forced back to F32, repeating over the whole
// computation until nothing changes. While bodies/conditions get their own
// inner fixed point; fusion and conditional computations are recursed into.
// Marks visited computations in `visited_computations` and returns whether
// any parameter's type changed.
ResolveInconsistencyOfAliasingBuffersHelper(HloComputation * computation,absl::flat_hash_set<const HloComputation * > * visited_computations)565 bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
566     HloComputation* computation,
567     absl::flat_hash_set<const HloComputation*>* visited_computations) {
568   bool parameter_changed = false;
569   auto insts = computation->MakeInstructionPostOrder();
570   // Do the adjustment on each instruction in the computation in reverse
571   // topological order.
572   while (true) {
573     bool any_change = false;
574     for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
575       auto hlo = *inst_it;
576       auto adjust_hlo_output = [&](const Shape& /* subshape */,
577                                    const ShapeIndex& index) {
578         auto output_type = OutputTypeAfterChange(hlo, index);
579         VLOG(2) << "output_type is " << ((output_type == BF16) ? "BF16" : "F32")
580                 << " for :" << hlo->ToString() << "\n";
581         if (output_type != F32 && output_type != BF16) {
582           return;
583         }
        // Start optimistically at BF16; any aliased F32 value demotes it.
584         PrimitiveType type = BF16;
585         for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
586           auto value_type = ValueTypeAfterChange(value);
587           if (value_type == BF16) {
588             continue;
589           }
590           VLOG(2) << "Adjust to F32 due to aliased dataflow value: "
591                   << value->ToString() << "\n";
592           CHECK_EQ(value_type, F32);
593           type = F32;
594           break;
595         }
596         // In order to find aliases due to in-place operations, use
597         // GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here,
598         // but this code works with HloModules that aren't ready yet to use
599         // HloAliasAnalysis (e.g., their computation graphs may not have been
600         // flattened yet).
601         for (const auto& operand_and_output_index :
602              HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) {
603           if (operand_and_output_index.second == index) {
604             const HloOperandIndex& operand_index =
605                 operand_and_output_index.first;
606             for (const auto* value :
607                  dataflow_
608                      ->GetValueSet(hlo->operand(operand_index.operand_number),
609                                    operand_index.operand_index)
610                      .values()) {
611               auto value_type = ValueTypeAfterChange(value);
612               if (value_type == BF16) {
613                 continue;
614               }
615               VLOG(2) << "Adjust to F32 due to InputOutPair: "
616                       << value->ToString() << "\n";
617               CHECK_EQ(value_type, F32);
618               type = F32;
619               break;
620             }
621           }
622         }
623 
624         // It's possible that a user has been changed from BF16 to F32
625         // during this final adjustment pass, so we need to check
626         // AllUsersConsumeBF16() again.
627         if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
628           VLOG(2) << "Adjust to F32 due to All user consumeBF16 fail\n";
629           type = F32;
630         }
631         if (type == F32) {
632           for (const auto* value :
633                dataflow_->GetValueSet(hlo, index).values()) {
634             // We rely on the fact that this adjustment works in reverse
635             // topological order. Adding the value to
636             // values_that_must_be_kept_as_f32_ will ensure the correctness
637             // of the adjustment for HLOs that will be processed later.
638             values_that_must_be_kept_as_f32_.insert(value);
639           }
640         }
641         if (type != output_type) {
642           any_change = true;
643           AddToOrRemoveFromBF16ChangeSet(hlo, index, type);
644           VLOG(2) << "HloInstruction output at shape index " << index
645                   << " adjusted to " << (type == BF16 ? "BF16" : "F32") << ": "
646                   << hlo->ToString();
647           if (hlo->opcode() == HloOpcode::kParameter) {
648             parameter_changed = true;
649           }
650         }
651       };
652       ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
653       AdjustCalledComputationRoot(hlo);
654       if (hlo->opcode() == HloOpcode::kWhile) {
655         // We need to run on the while body and condition repeatedly until a
656         // fixed point is reached, i.e., the parameters do not change any more.
657         // We may need more than one iteration because the while input and
658         // output alias each other, so changing one input parameter requires
659         // changing the corresponding output element and thus may transitively
660         // require changing another input parameter. A fixed point will be
661         // reached because the parameters can only be changed from BF16 to F32,
662         // not the other way around.
663         absl::flat_hash_set<const HloComputation*> visited_in_while;
664         while (ResolveInconsistencyOfAliasingBuffersHelper(
665                    hlo->while_condition(), &visited_in_while) ||
666                ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
667                                                            &visited_in_while)) {
668           visited_in_while.clear();
669           ShapeUtil::ForEachSubshape(hlo->shape(), adjust_hlo_output);
670           AdjustCalledComputationRoot(hlo);
671         }
672         visited_computations->insert(visited_in_while.begin(),
673                                      visited_in_while.end());
674       } else if (hlo->opcode() == HloOpcode::kFusion) {
675         ResolveInconsistencyOfAliasingBuffersHelper(
676             hlo->fused_instructions_computation(), visited_computations);
677       } else if (hlo->opcode() == HloOpcode::kConditional) {
678         for (auto* branch : hlo->branch_computations()) {
679           ResolveInconsistencyOfAliasingBuffersHelper(branch,
680                                                       visited_computations);
681         }
682       }
683     }
684     if (!any_change) {
685       break;
686     }
687   }
688   // Now adjust parameters of called computations.
689   for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
690     AdjustCalledComputationParameters(*inst_it);
691   }
692   return parameter_changed;
693 }
694 
ResolveInconsistencyOfAliasingBuffers(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)695 void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
696     HloModule* module,
697     const absl::flat_hash_set<absl::string_view>& execution_threads) {
698   const auto& computations_topological_order =
699       module->MakeComputationPostOrder(execution_threads);
700   absl::flat_hash_set<const HloComputation*> resolved;
701   for (auto comp_it = computations_topological_order.rbegin();
702        comp_it != computations_topological_order.rend(); ++comp_it) {
703     if (ContainsKey(resolved, *comp_it)) {
704       continue;
705     }
706     ResolveInconsistencyOfAliasingBuffersHelper(*comp_it, &resolved);
707   }
708 }
709 
ResolveInconsistentFusions(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)710 Status BFloat16Propagation::ResolveInconsistentFusions(
711     HloModule* module,
712     const absl::flat_hash_set<absl::string_view>& execution_threads) {
713   // We could have changed a fusion computation's root shape to have a different
714   // precision than the fusion node's output, if the fusion root does not
715   // define a buffer (e.g., a tuple). Now we add conversions after such fusion
716   // roots to make them match the fusion output. If the fusion output is a
717   // (possibly nested) tuple, we first create get-tuple-elements, then convert
718   // the unmatching leaf nodes, and finally create a new tuple as the fusion
719   // computation's root. If tuples and get-tuple-elements are created, we will
720   // run tuple simplifier and dead code elimination at the end (dead code is not
721   // allowed in fusion computation). E.g.,
722   //
723   // (1)             (2)             (3)
724   // a  b            a  b            a  b
725   // |\ |            |\ |            |\ |
726   // \ add   ->      |add    ->      | add
727   //  \ |            \ |        convert |
728   //  tuple         tuple             \ |
729   //                 / \              tuple
730   //               gte gte
731   //                |   |
732   //           convert  |
733   //                 \  /
734   //                 tuple
735   // (1) a is F32 but tuple is BF16
736   // (2) after adding conversion
737   // (3) after tuple simplifier and DCE.
738   for (auto computation : module->MakeComputationPostOrder(execution_threads)) {
739     auto insts = computation->MakeInstructionPostOrder();
740     for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
741       auto hlo = *inst_it;
742       if (hlo->opcode() != HloOpcode::kFusion) {
743         continue;
744       }
745       auto fusion_computation = hlo->fused_instructions_computation();
746       auto fusion_root = fusion_computation->root_instruction();
747       if (ShapeUtil::Compatible(fusion_root->shape(), hlo->shape())) {
748         continue;
749       }
750       ShapeTree<HloInstruction*> converted_outputs(hlo->shape());
751       // Deep copy the fusion root, and convert a leaf node only if its shape
752       // does not match the fusion output.
753       TF_ASSIGN_OR_RETURN(
754           HloInstruction * copy,
755           fusion_computation->DeepCopyInstructionWithCustomCopier(
756               fusion_root,
757               [hlo](HloInstruction* leaf, const ShapeIndex& leaf_index,
758                     HloComputation* comp) {
759                 const Shape& hlo_subshape =
760                     ShapeUtil::GetSubshape(hlo->shape(), leaf_index);
761                 if (ShapeUtil::Compatible(leaf->shape(), hlo_subshape)) {
762                   return leaf;
763                 }
764                 return comp->AddInstruction(
765                     HloInstruction::CreateConvert(hlo_subshape, leaf));
766               }));
767       fusion_computation->set_root_instruction(copy);
768     }
769   }
770   return OkStatus();
771 }
772 
ResolveConvertedConstants(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)773 Status BFloat16Propagation::ResolveConvertedConstants(
774     HloModule* module,
775     const absl::flat_hash_set<absl::string_view>& execution_threads) {
776   // We may have converted some constants from F32 to BF16, so adjust the
777   // constant literals in such cases. We do this here instead of when the
778   // constant node's is changed because 1) the HloInstruction interface does not
779   // allow resetting the literal so we have to create a new kConstant
780   // instruction to replace the old one, which invalidates dataflow analysis,
781   // and 2) it's possible that a kConstant's output gets changed to BF16 at the
782   // beginning but later on adjusted back to F32, so converting literals here
783   // can avoid repeated conversions.
784   //
785   // TODO(b/73833576): Consider resetting literal in HloInstruction.
786   for (auto computation : module->MakeComputationPostOrder(execution_threads)) {
787     for (auto hlo : computation->MakeInstructionPostOrder()) {
788       if (hlo->opcode() != HloOpcode::kConstant) {
789         continue;
790       }
791       if (!Shape::Equal().MinorToMajorOnlyInLayout()(hlo->literal().shape(),
792                                                      hlo->shape())) {
793         TF_ASSIGN_OR_RETURN(auto converted_literal,
794                             hlo->literal().ConvertToShape(hlo->shape()));
795         auto new_constant = computation->AddInstruction(
796             HloInstruction::CreateConstant(std::move(converted_literal)));
797         UpdateLayout(new_constant->mutable_shape());
798         TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_constant));
799       }
800     }
801   }
802   return OkStatus();
803 }
804 
SkipNoopConversions(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)805 Status BFloat16Propagation::SkipNoopConversions(
806     HloModule* module,
807     const absl::flat_hash_set<absl::string_view>& execution_threads) {
808   for (auto computation : module->computations(execution_threads)) {
809     for (auto hlo : computation->MakeInstructionPostOrder()) {
810       if (hlo->opcode() != HloOpcode::kConvert) {
811         continue;
812       }
813       auto source = hlo->mutable_operand(0);
814       if (!ShapeUtil::Equal(source->shape(), hlo->shape())) {
815         continue;
816       }
817       const bool is_root = hlo == computation->root_instruction();
818       TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(source));
819       if (is_root) {
820         computation->set_root_instruction(source);
821       }
822     }
823   }
824   return OkStatus();
825 }
826 
827 // The algorithm first does a forward pass (parameters to root) to determine a
828 // set of instructions to consider using bfloat16, then does a backward pass to
829 // determine the precisions of those instructions according to the need of
830 // their users. During the backward pass, the potential changes are stored in
831 // changes_to_bf16_ which are subject to further adjustments then applied to the
832 // HLOs.
StatusOr<bool> BFloat16Propagation::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  // Reset all per-run state so the pass can be run repeatedly (e.g., on
  // different modules) without stale data from a previous invocation.
  consider_using_bfloat16_.clear();
  instructions_visited_in_backward_pass_.clear();
  computations_visited_in_backward_pass_.clear();
  values_that_must_be_kept_as_f32_.clear();
  caller_counts_.clear();
  changes_to_bf16_.clear();
  changed_ = false;

  auto computations_topological_order =
      module->MakeComputationPostOrder(execution_threads);

  // Before running the propagation pass, we insert copies (kConvert to the same
  // type) of F32 inputs to while loops. This prevents other uses of the same
  // input from aliasing the while loop input/output, so that there's greater
  // chance to use BF16 inside the loop. If some of these added copies do not
  // help, they will remain F32 after BF16 propagation and will be removed since
  // they are no-ops.
  for (auto computation : computations_topological_order) {
    for (auto inst : computation->MakeInstructionPostOrder()) {
      if (inst->opcode() != HloOpcode::kWhile) {
        continue;
      }

      auto operand = inst->mutable_operand(0);
      // Deep-copy the while operand, inserting a same-type kConvert on each
      // F32 leaf; non-F32 leaves are left as-is.
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          computation->DeepCopyInstructionWithCustomCopier(
              operand, [](HloInstruction* leaf, const ShapeIndex& leaf_index,
                          HloComputation* comp) {
                if (leaf->shape().element_type() != F32) {
                  return leaf;
                }
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(leaf->shape(), leaf));
              }));
      // Only this while's use of the operand is redirected to the copy; other
      // users keep the original value, breaking the aliasing.
      TF_RETURN_IF_ERROR(operand->ReplaceUseWith(inst, copy));
    }
  }

  // Dataflow analysis is run after the copies above so it reflects the
  // de-aliased graph; it is used throughout the passes below.
  TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));

  // The first step is a forward pass (parameters to root), where we determine
  // the potential candidate instructions to use bfloat16 in the outputs that
  // are not likely to cause overhead from extra explicit conversions. This is
  // done forwardly because we determine whether an HLO is a candidate partially
  // based on whether its operands are candidates.
  for (auto computation : computations_topological_order) {
    for (auto inst : computation->MakeInstructionPostOrder()) {
      if (InstructionIsCandidateForBF16Output(inst)) {
        consider_using_bfloat16_.insert(inst);
      }
    }
  }

  // The second step is a backward pass (root to parameters), where we modify
  // the precisions of the instructions identified in the first step when
  // feasible. This is done backwardly because we determine the precision of an
  // HLO's output based on how it is later used.
  //
  // The precision of an instruction is determined by its users, so we do the
  // propagation in reverse topological order.
  for (auto comp_it = computations_topological_order.rbegin();
       comp_it != computations_topological_order.rend(); ++comp_it) {
    // Skip computations already handled transitively (e.g., while bodies
    // processed when visiting their callers).
    if (ContainsKey(computations_visited_in_backward_pass_, *comp_it)) {
      continue;
    }
    auto insts = (*comp_it)->MakeInstructionPostOrder();
    for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
      DetermineInstructionPrecision(*inst_it,
                                    /*skip_parameters=*/true);
    }
    computations_visited_in_backward_pass_.insert(*comp_it);
  }

  // It's possible that an instruction does not define a buffer, but the
  // defining instruction's shape has changed. So we need to adjust the output
  // shapes of instructions according to the HLO values they refer to.
  ResolveInconsistencyOfAliasingBuffers(module, execution_threads);

  // Apply the changes in changes_to_bf16_.
  for (auto& change : changes_to_bf16_) {
    auto inst = change.first;
    // It is possible that we marked inst to change precision even if it is an
    // unsupported change, when inst is the root of a fusion computation and it
    // has to match the fusion node's output precision. We do a convert instead
    // of in-place change for such cases.
    if (ShouldKeepPrecisionUnchanged(inst)) {
      // Snapshot users before rewiring, since ReplaceUseWithDifferentShape
      // mutates the use list.
      auto users = inst->users();
      bool is_root = inst == inst->parent()->root_instruction();
      // Deep-copy inst, converting to BF16 exactly the leaves recorded in
      // change.second; other leaves are passed through unchanged.
      TF_ASSIGN_OR_RETURN(
          HloInstruction * copy,
          inst->parent()->DeepCopyInstructionWithCustomCopier(
              inst, [&](HloInstruction* leaf, const ShapeIndex& leaf_index,
                        HloComputation* comp) {
                if (!ContainsKey(change.second,
                                 ShapeUtil::GetMutableSubshape(
                                     inst->mutable_shape(), leaf_index))) {
                  return leaf;
                }
                auto converted_shape =
                    ShapeUtil::ChangeElementType(leaf->shape(), BF16);
                UpdateLayout(&converted_shape);
                return comp->AddInstruction(
                    HloInstruction::CreateConvert(converted_shape, leaf));
              }));
      for (auto user : users) {
        TF_RETURN_IF_ERROR(inst->ReplaceUseWithDifferentShape(user, copy));
      }
      if (is_root) {
        inst->parent()->set_root_instruction(copy,
                                             /*accept_different_shape=*/true);
      }
      continue;
    }
    // Supported in-place change: flip each recorded subshape from F32 to BF16
    // directly on the instruction's shape.
    for (const auto& entry : change.second) {
      auto subshape = entry.first;
      CHECK_EQ(subshape->element_type(), F32);
      subshape->set_element_type(BF16);
      UpdateLayout(subshape);
      changed_ = true;
    }
  }

  // Removes redundant HLOs added by this pass, either when inserting
  // de-aliasing copies to while loop inputs, or later when converting output
  // types.
  auto clean_up = [this, module, &execution_threads]() {
    TF_RETURN_IF_ERROR(SkipNoopConversions(module, execution_threads));
    TupleSimplifier tuple_simplifier;
    TF_RETURN_IF_ERROR(
        tuple_simplifier.Run(module, execution_threads).status());
    HloDCE dce;
    TF_RETURN_IF_ERROR(dce.Run(module, execution_threads).status());
    return OkStatus();
  };

  // Even if nothing was changed to BF16, the de-aliasing copies inserted at
  // the start must still be cleaned up before returning.
  if (!changed_) {
    TF_RETURN_IF_ERROR(clean_up());
    return false;
  }

  TF_RETURN_IF_ERROR(ResolveInconsistentFusions(module, execution_threads));
  TF_RETURN_IF_ERROR(ResolveConvertedConstants(module, execution_threads));

  TF_RETURN_IF_ERROR(clean_up());
  return true;
}
983 
OutputTypeAfterChange(HloInstruction * hlo,const ShapeIndex & index) const984 PrimitiveType BFloat16Propagation::OutputTypeAfterChange(
985     HloInstruction* hlo, const ShapeIndex& index) const {
986   Shape* subshape = ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index);
987   const PrimitiveType type_on_hlo = subshape->element_type();
988   if (type_on_hlo != F32) {
989     return type_on_hlo;
990   }
991   auto it = changes_to_bf16_.find(hlo);
992   if (it == changes_to_bf16_.end()) {
993     return type_on_hlo;
994   }
995   return ContainsKey(it->second, subshape) ? BF16 : F32;
996 }
997 
ValueTypeAfterChange(const HloValue * value) const998 PrimitiveType BFloat16Propagation::ValueTypeAfterChange(
999     const HloValue* value) const {
1000   auto hlo = value->defining_instruction();
1001   const auto& position = value->defining_position();
1002   return OutputTypeAfterChange(hlo, position.index);
1003 }
1004 
AddToOrRemoveFromBF16ChangeSet(HloInstruction * hlo,const ShapeIndex & index,PrimitiveType target_type)1005 void BFloat16Propagation::AddToOrRemoveFromBF16ChangeSet(
1006     HloInstruction* hlo, const ShapeIndex& index, PrimitiveType target_type) {
1007   if (target_type == BF16) {
1008     auto& entry = changes_to_bf16_[hlo];
1009     entry.emplace(ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index),
1010                   index);
1011   } else {
1012     CHECK_EQ(target_type, F32);
1013     auto it = changes_to_bf16_.find(hlo);
1014     if (it == changes_to_bf16_.end()) {
1015       return;
1016     }
1017     it->second.erase(
1018         ShapeUtil::GetMutableSubshape(hlo->mutable_shape(), index));
1019   }
1020 }
1021 
1022 }  // namespace xla
1023