/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_

#include <memory>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {

// HLO pass which reduces the precision of some HLO instructions to BF16
// according to the backend-specific BFloat16Support rule provided by the
// caller.
//
// This pass can be used to reduce instruction precision without affecting the
// numerical accuracy of the module, i.e., the final output of the module
// would be bitwise identical to that without this pass; this is possible if
// the backend already reduces precision to BF16 on some HLO instructions.
//
// This pass will not modify the signature of a computation, unless it is a
// fusion computation or its only caller is a while.
//
// !!! WARNING !!! This pass can introduce mixed precision in individual HLOs,
// which has two issues:
//
// 1) It does not guarantee to respect the passed-in BFloat16Support
// specification in terms of mixed precision, so the backend may not support
// an HLO that has mixed precision produced by this pass. To address this
// issue, run BFloat16Normalization with the same BFloat16Support after this
// pass.
//
// 2) In general, mixed precision may break the assumptions of some other HLO
// passes even if the specific backend supports the individual HLOs. Such
// assumptions include that there are no HLOs using mixed precision, or that
// the precision of an HLO's output is determined by its inputs. Therefore,
// this pass should be used at the end of the HLO optimization pipeline but
// before BFloat16ConversionFolding. If other passes are needed after this
// pass, run BFloat16MixedPrecisionRemoval first to undo some of the changes
// made by this pass.
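//
// A minimal usage sketch, assuming the backend provides a BFloat16Support
// instance named bf16_support; the exact pipeline contents below are
// illustrative, not prescriptive:
//
//   HloPassPipeline pipeline("bf16-propagation");
//   pipeline.AddPass<BFloat16Propagation>(&bf16_support);
//   // Clean up mixed precision introduced by the propagation pass.
//   pipeline.AddPass<BFloat16Normalization>(&bf16_support);
//   pipeline.AddPass<BFloat16ConversionFolding>(&bf16_support);
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));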
class BFloat16Propagation : public HloModulePass {
 public:
  explicit BFloat16Propagation(const BFloat16Support* bfloat16_support);

  ~BFloat16Propagation() override = default;

  absl::string_view name() const override { return "bfloat16-propagation"; }

  // Runs the pass on the given module. Returns whether the module was changed
  // (precision reductions were added).
  using HloPassInterface::Run;
  StatusOr<bool> Run(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads)
      override;

  // Returns whether we should avoid changing the precision of inst regardless
  // of its producers and users.
  virtual bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst);

  // Determines whether we should consider changing the precision of the given
  // instruction in the forward pass.
  virtual bool InstructionIsCandidateForBF16Output(HloInstruction* hlo);
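
  // The two methods above are virtual so that a backend can subclass this
  // pass and further restrict propagation. A minimal sketch (the opcode
  // choice is purely illustrative):
  //
  //   class MyBF16Propagation : public BFloat16Propagation {
  //    public:
  //     using BFloat16Propagation::BFloat16Propagation;
  //     bool ShouldKeepPrecisionUnchanged(const HloInstruction* inst) override {
  //       return inst->opcode() == HloOpcode::kSort ||
  //              BFloat16Propagation::ShouldKeepPrecisionUnchanged(inst);
  //     }
  //   };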

 private:
  // ***************************
  // Function called and state produced by the forward analysis pass (from
  // parameters to root) that determines the candidate HLOs to use BF16
  // outputs.

  // The set of instructions to consider using bfloat16, computed in the
  // forward pass.
  absl::flat_hash_set<const HloInstruction*> consider_using_bfloat16_;

  // ***************************
  // Functions called and state produced by the backward pass (from root to
  // parameters) that finds opportunities to use BF16.

  // Determines the precision for the given instruction in the
  // opportunity-finding pass.
  void DetermineInstructionPrecision(HloInstruction* hlo,
                                     bool skip_parameters);

  // Special handling in the opportunity-finding pass for fusion computations.
  //
  // Precondition: hlo->opcode() == kFusion
  void DetermineFusionComputationPrecision(HloInstruction* fusion);

  // Reverts changes to BF16 that will not propagate outside a fusion
  // computation. This avoids the overhead of BF16 casts inside a fusion,
  // which would not save memory bandwidth.
  //
  // Precondition: hlo->opcode() == kFusion
  void RevertIfFusionInternalBF16Changes(HloInstruction* fusion);

  // Special handling in the opportunity-finding pass for while computations.
  //
  // Precondition: hlo->opcode() == kWhile
  void DetermineWhileComputationsPrecision(HloInstruction* while_hlo);

  // Special handling in the opportunity-finding pass for conditional
  // branches.
  //
  // Precondition: hlo->opcode() == kConditional
  void DetermineConditionalComputationsPrecision(HloInstruction* cond);

  // The set of HloInstructions that have been visited in the
  // opportunity-finding pass.
  absl::flat_hash_set<const HloInstruction*>
      instructions_visited_in_backward_pass_;

  // The set of HloComputations that have been visited in the
  // opportunity-finding pass.
  absl::flat_hash_set<const HloComputation*>
      computations_visited_in_backward_pass_;

  // ***************************
  // Functions called by the final inconsistency-resolving pass.

  // Adjusts the output shapes of HloInstructions such that if two
  // HloInstructions have aliasing buffers in their outputs, they must have
  // the same precision.
  void ResolveInconsistencyOfAliasingBuffers(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // Resolves inconsistency of aliasing buffers for the given computation, and
  // recursively runs on a while instruction's condition and body until a
  // fixed point is reached.
  bool ResolveInconsistencyOfAliasingBuffersHelper(
      HloComputation* computation,
      absl::flat_hash_set<const HloComputation*>* visited_computations);

  // Makes the parameters of called computations match how they are called by
  // the given HLO.
  void AdjustCalledComputationParameters(HloInstruction* hlo);

  // Makes the root instructions of called computations match how they are
  // used by the given HLO.
  void AdjustCalledComputationRoot(HloInstruction* hlo);

  // ***************************
  // Functions called after changes in changes_to_bf16_ are applied.

  // Resolves inconsistencies introduced by this pass for fusions with
  // tuple-type output.
  Status ResolveInconsistentFusions(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // Converts the literals in kConstant HLOs which have their types changed to
  // BF16 by this pass.
  Status ResolveConvertedConstants(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // Skips no-op conversions (same source and target shapes) that can be
  // produced by this pass, i.e., replaces them in their uses with their
  // operands.
  Status SkipNoopConversions(
      HloModule* module,
      const absl::flat_hash_set<absl::string_view>& execution_threads);

  // ***************************
  // Functions called and state used by two or more passes.

  // Returns whether all uses of the given HloInstruction can consume BF16
  // input.
  bool AllUsersConsumeBF16(const HloInstruction& hlo,
                           const ShapeIndex& index) const;

  // The output element type of the HLO at the given shape index after changes
  // in changes_to_bf16_ are applied.
  PrimitiveType OutputTypeAfterChange(HloInstruction* hlo,
                                      const ShapeIndex& index) const;

  // The element type of the HLO value after changes in changes_to_bf16_ are
  // applied.
  PrimitiveType ValueTypeAfterChange(const HloValue* value) const;

  // If target_type == BF16, adds the HLO at the given index to
  // changes_to_bf16_; otherwise, target_type must be F32 and this function
  // removes the HLO at the given index from changes_to_bf16_ if it was
  // earlier added.
  void AddToOrRemoveFromBF16ChangeSet(HloInstruction* hlo,
                                      const ShapeIndex& index,
                                      PrimitiveType target_type);

  // The set of F32 HLO values that must be kept in F32.
  absl::flat_hash_set<const HloValue*> values_that_must_be_kept_as_f32_;

  // Mapping from each HloComputation to the number of callers to it in the
  // module. Populated at the beginning of this pass.
  absl::flat_hash_map<const HloComputation*, int64_t> caller_counts_;

  // We first store the potential F32-to-BF16 changes in changes_to_bf16_,
  // which are subject to further adjustment, and only then apply them to the
  // HLOs. This avoids setting changed_ to true when all of the stored changes
  // end up being reverted during adjustment.
  //
  // For each HloInstruction, changes_to_bf16_ stores the affected buffers in
  // the output as a map from in-place pointers to subshapes to shape indices.
  absl::flat_hash_map<HloInstruction*, absl::flat_hash_map<Shape*, ShapeIndex>>
      changes_to_bf16_;
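
  // An illustrative (hypothetical) example of an entry in the map above: if
  // tuple element {1} of a fusion's output is to be changed to BF16, then
  // changes_to_bf16_[fusion] contains a mapping from the Shape* of that
  // subshape (a pointer into the fusion's output shape) to ShapeIndex{1}.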

  // Whether the last processed HLO module has been changed by this pass.
  bool changed_ = false;

  const BFloat16Support* bfloat16_support_;
  std::unique_ptr<HloDataflowAnalysis> dataflow_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_