/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_

#include <memory>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {

// HLO pass which reduces the precision of some HLO instructions to BF16
// according to the backend-specific BFloat16Support rule provided by the
// caller.
//
// This pass can be used to reduce instruction precision without affecting the
// numerical accuracy of the module, i.e., the final output of the module would
// be bitwise identical to that without this pass; this is possible if the
// backend already reduces precision to BF16 on some HLO instructions.
//
// This pass will not modify the signature of a computation, unless it is a
// fusion computation or its only caller is a while.
//
// !!! WARNING !!! This pass can introduce mixed precision in individual HLOs,
// which has two issues:
//
// 1) It is not guaranteed to respect the passed-in BFloat16Support
// specification in terms of mixed precision, so the backend may not support an
// HLO that has mixed precision produced by this pass. To address this issue,
// run BFloat16Normalization with the same BFloat16Support after this pass.
//
// 2) In general, mixed precision may break the assumptions of some other HLO
// passes even if the specific backend supports the individual HLOs. Such
// assumptions include that there are no HLOs using mixed precision, or that
// the precision of an HLO's output is determined by its inputs. For this
// reason, this pass should be used at the end of the HLO optimization
// pipeline but before BFloat16ConversionFolding. If other passes are needed
// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of
// the changes made by this pass.
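//
// For illustration, a minimal sketch of the recommended pipeline ordering
// described above (assumes a backend-specific BFloat16Support implementation
// named `bf16_support` and an HloModule* named `module`):
//
//   HloPassPipeline pipeline("bf16");
//   // Propagate BF16 first, then normalize away any unsupported mixed
//   // precision this pass may have introduced.
//   pipeline.AddPass<BFloat16Propagation>(&bf16_support);
//   pipeline.AddPass<BFloat16Normalization>(&bf16_support);
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));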
class BFloat16Propagation : public HloModulePass {
 public:
  explicit BFloat16Propagation(const BFloat16Support* bfloat16_support);

  ~BFloat16Propagation() override = default;

  absl::string_view name() const override { return "bfloat16-propagation"; }

  // Runs the pass on the given module. Returns whether the module was changed
  // (precision reductions were added).
  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
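  //
  // For example, to run the pass standalone (a sketch; assumes `module` is an
  // HloModule* and `bf16_support` is a backend-specific BFloat16Support
  // implementation):
  //
  //   BFloat16Propagation propagation(&bf16_support);
  //   TF_ASSIGN_OR_RETURN(bool changed, propagation.Run(module));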

  // Returns whether we should avoid changing the precision of inst regardless
  // of the producers and users.
  virtual bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst);

  // Determines whether we should consider changing the precision of the given
  // instruction in the forward pass.
  virtual bool InstructionIsCandidateForBF16Output(HloInstruction* hlo);

 private:
  // ***************************
  // Functions called and state produced by the forward analysis pass (from
  // parameters to root) that determines the candidate HLOs to use BF16
  // outputs.

  // The set of instructions to consider using bfloat16, computed in the
  // forward pass.
  absl::flat_hash_set<const HloInstruction*> consider_using_bfloat16_;

  // ***************************
  // Functions called and state produced by the backward pass (from root to
  // parameters) that finds opportunities to use BF16.

  // Determines the precision for the given instruction in the
  // opportunity-finding pass.
  void DetermineInstructionPrecision(HloInstruction* hlo, bool skip_parameters);

  // Special handling in the opportunity-finding pass for fusion computations.
  //
  // Precondition: hlo->opcode() == kFusion
  void DetermineFusionComputationPrecision(HloInstruction* fusion);

  // Reverts changes to BF16 that will not propagate outside a fusion
  // computation. This avoids the overhead of BF16 casts inside a fusion when
  // they would not save memory bandwidth.
  //
  // Precondition: hlo->opcode() == kFusion
  void RevertIfFusionInternalBF16Changes(HloInstruction* fusion);

  // Special handling in the opportunity-finding pass for while computations.
  //
  // Precondition: hlo->opcode() == kWhile
  void DetermineWhileComputationsPrecision(HloInstruction* while_hlo);

  // Special handling in the opportunity-finding pass for conditional branches.
  //
  // Precondition: hlo->opcode() == kConditional
  void DetermineConditionalComputationsPrecision(HloInstruction* cond);

  // The set of HloInstructions that have been visited in the
  // opportunity-finding pass.
  absl::flat_hash_set<const HloInstruction*>
      instructions_visited_in_backward_pass_;

  // The set of HloComputations that have been visited in the
  // opportunity-finding pass.
  absl::flat_hash_set<const HloComputation*>
      computations_visited_in_backward_pass_;

  // ***************************
  // Functions called by the final inconsistency-resolving pass.

  // Adjusts the output shapes of HloInstructions such that if two
  // HloInstructions have aliasing buffers in their outputs, those buffers
  // have the same precision.
  void ResolveInconsistencyOfAliasingBuffers(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // Resolves inconsistency of aliasing buffers for the given computation, and
  // recursively runs on a while instruction's condition and body until a fixed
  // point is reached.
  bool ResolveInconsistencyOfAliasingBuffersHelper(
      HloComputation* computation,
      absl::flat_hash_set<const HloComputation*>* visited_computations);

  // Makes the parameters of called computations match how they are called by
  // the given HLO.
  void AdjustCalledComputationParameters(HloInstruction* hlo);

  // Makes the root instructions of called computations match how they are used
  // by the given HLO.
  void AdjustCalledComputationRoot(HloInstruction* hlo);

  // ***************************
  // Functions called after the changes in changes_to_bf16_ are applied.

  // Resolves inconsistencies introduced by this pass for fusions with
  // tuple-type output.
  Status ResolveInconsistentFusions(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // Converts the literals in kConstant HLOs which have their types changed to
  // BF16 by this pass.
  Status ResolveConvertedConstants(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // Skips no-op conversions (same source and target shapes) that can be
  // produced by this pass, i.e., replaces them in their uses with their
  // operands.
  Status SkipNoopConversions(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // ***************************
  // Functions called and state used by two or more passes.

  // Returns whether all uses of the given HloInstruction can consume BF16
  // input.
  bool AllUsersConsumeBF16(const HloInstruction& hlo,
                           const ShapeIndex& index) const;

  // The output element type of the HLO at the given shape index after changes
  // in changes_to_bf16_ are applied.
  PrimitiveType OutputTypeAfterChange(HloInstruction* hlo,
                                      const ShapeIndex& index) const;

  // The element type of the HLO value after changes in changes_to_bf16_ are
  // applied.
  PrimitiveType ValueTypeAfterChange(const HloValue* value) const;

  // If target_type == BF16, adds the HLO at the given index to
  // changes_to_bf16_; otherwise, target_type must be F32 and this function
  // removes the HLO at the given index from changes_to_bf16_ if it was earlier
  // added.
  void AddToOrRemoveFromBF16ChangeSet(HloInstruction* hlo,
                                      const ShapeIndex& index,
                                      PrimitiveType target_type);

  // The set of F32 HLO values that must be kept in F32.
  absl::flat_hash_set<const HloValue*> values_that_must_be_kept_as_f32_;

  // Mapping from each HloComputation to the number of callers to it in the
  // module. Populated at the beginning of this pass.
  absl::flat_hash_map<const HloComputation*, int64_t> caller_counts_;

  // We first record the potential F32-to-BF16 changes in changes_to_bf16_,
  // adjust them further, and only then apply them to the HLOs. This avoids
  // setting changed_ to true when all of the recorded changes end up being
  // reverted during adjustment.
  //
  // For each HloInstruction, changes_to_bf16_ stores the affected buffers in
  // its output as a map from (in-place) pointers to the output subshapes to
  // the corresponding shape indices.
  absl::flat_hash_map<HloInstruction*, absl::flat_hash_map<Shape*, ShapeIndex>>
      changes_to_bf16_;
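  //
  // A minimal sketch of how an entry above is eventually consumed when the
  // recorded changes are applied (the actual application logic lives in the
  // .cc file and also updates layouts):
  //
  //   for (auto& change : changes_to_bf16_) {
  //     for (auto& entry : change.second) {
  //       Shape* subshape = entry.first;  // points into the HLO's shape
  //       subshape->set_element_type(BF16);
  //     }
  //   }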

  // Whether the last processed HLO module has been changed by this pass.
  bool changed_ = false;

  const BFloat16Support* bfloat16_support_;
  std::unique_ptr<HloDataflowAnalysis> dataflow_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_