xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/fusion_merger.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
17 
18 #include <algorithm>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/algorithm/container.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
25 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
26 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
27 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
28 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/util.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 
33 namespace xla {
34 namespace gpu {
35 
36 namespace {
37 
38 // Traverses users of tuple shape, adding leaf instructions to 'instructions'.
MaybeResolveTupleElements(HloInstruction * instruction,std::vector<HloInstruction * > * instructions)39 void MaybeResolveTupleElements(HloInstruction* instruction,
40                                std::vector<HloInstruction*>* instructions) {
41   if (instruction->shape().IsTuple()) {
42     for (auto tuple_user : instruction->users()) {
43       MaybeResolveTupleElements(tuple_user, instructions);
44     }
45   } else {
46     instructions->push_back(instruction);
47   }
48 }
49 
50 // Returns the bytes read by fusion parameter 'param', by returning the byte
51 // size of 'param' shape (or the cumulative byte sizes of all leaf tuple
52 // elements if 'param' is tuple-shaped).
53 //
54 // In the special case where all users of 'param' (or all users of a leaf
55 // tuple element if 'param' is tuple-shaped) are Slice instructions, the size
56 // of each slice instruction is accumulated instead, to give a more accurate
57 // value for bytes read.
CalculateBytesReadByFusionParameter(HloInstruction * param)58 double CalculateBytesReadByFusionParameter(HloInstruction* param) {
59   CHECK_EQ(HloOpcode::kParameter, param->opcode());
60 
61   // Adds all leaf tuple elements to 'instructions' if 'param' is tuple-shaped.
62   // Adds 'param' to 'instructions' otherwise.
63   std::vector<HloInstruction*> instructions;
64   MaybeResolveTupleElements(param, &instructions);
65 
66   // Iterate through 'instructions' accumulating byte sizes of each instruction
67   // shape. For each 'instruction' in 'instructions', if all users of
68   // 'instruction' are Slice instructions, accumulates the byte sizes of each
69   // Slice for a more accurate estimate of bytes read.
70   double bytes = 0.0;
71   for (auto& instruction : instructions) {
72     if (absl::c_all_of(
73             instruction->users(), [](const HloInstruction* instruction) {
74               return instruction->opcode() == HloOpcode::kSlice ||
75                      instruction->opcode() == HloOpcode::kDynamicSlice;
76             })) {
77       // All users are slice: accumulate bytes of all user slice instructions.
78       for (auto& user : instruction->users()) {
79         bytes += ShapeUtil::ByteSizeOf(user->shape());
80       }
81     } else {
82       // Some users are not slice: accumulate full size of 'instruction'.
83       bytes += ShapeUtil::ByteSizeOf(instruction->shape());
84     }
85   }
86   return bytes;
87 }
88 
89 // Returns the bytes read by all fusion parameters of instruction 'fusion'.
CalculateBytesReadByFusionInstruction(HloInstruction * fusion)90 double CalculateBytesReadByFusionInstruction(HloInstruction* fusion) {
91   double bytes = 0.0;
92   for (auto* fused_instruction : fusion->fused_instructions()) {
93     if (fused_instruction->opcode() != HloOpcode::kParameter) {
94       continue;
95     }
96     bytes += CalculateBytesReadByFusionParameter(fused_instruction);
97   }
98   return bytes;
99 }
100 
101 // Returns bytes transferred by instruction 'fusion', including the bytes
102 // that would be read by all users.
GetCurrentBytesTransferred(HloInstruction * fusion)103 double GetCurrentBytesTransferred(HloInstruction* fusion) {
104   CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
105   const double bytes_read = CalculateBytesReadByFusionInstruction(fusion);
106   double bytes_written = 0;
107   if (fusion->IsMultiOutputFusion()) {
108     for (auto& operand : fusion->fused_expression_root()->operands()) {
109       bytes_written += ShapeUtil::ByteSizeOf(operand->shape());
110     }
111   } else {
112     bytes_written =
113         ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape());
114   }
115   // Current bytes transferred (ignoring non 'fusion' user operands) is bytes
116   // read and written by 'fusion', plus reads of size 'bytes_written' for each
117   // user.
118   return bytes_read + bytes_written * (fusion->user_count() + 1);
119 }
120 
121 // Returns bytes transferred if 'fusion' were to be merged into its users.
GetMergedBytesTransferred(HloInstruction * fusion)122 double GetMergedBytesTransferred(HloInstruction* fusion) {
123   CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
124   return CalculateBytesReadByFusionInstruction(fusion) * fusion->user_count();
125 }
126 
127 }  // anonymous namespace
128 
129 // For each fusion F, attempts to fuse F into *all* of F's users (does not fuse
130 // if can't fuse into at least one).
131 class FusionInstructionMerger {
132  public:
FusionInstructionMerger(HloComputation * computation)133   explicit FusionInstructionMerger(HloComputation* computation)
134       : computation_(computation),
135         dump_fusion_visualization_(computation->parent()
136                                        ->config()
137                                        .debug_options()
138                                        .xla_dump_fusion_visualization()) {}
139 
140   Status Run();
141 
changed() const142   bool changed() const { return changed_; }
143 
144  private:
145   FusionDecision HandleFusion(HloInstruction* fusion);
146   Status FuseIntoAllUsers(HloInstruction* instruction);
147 
148   HloComputation* computation_;
149 
150   bool changed_ = false;
151   bool dump_fusion_visualization_ = false;
152 
153   // Fusion instruction merge stats.
154   int total_visited_ = 0;
155   int total_merged_ = 0;
156   int num_fail_no_users_ = 0;
157   int num_fail_not_loop_fusion_ = 0;
158   int num_fail_merge_all_users_ = 0;
159   int num_fail_expensive_fused_instruction_ = 0;
160   int num_fail_net_bytes_transferred_ratio_ = 0;
161   int num_fail_inefficient_fusion_emitter_ = 0;
162   int num_fail_fusion_too_large_ = 0;
163 
164   FusionInstructionMerger(const FusionInstructionMerger&) = delete;
165   FusionInstructionMerger& operator=(const FusionInstructionMerger&) = delete;
166 };
167 
FuseIntoAllUsers(HloInstruction * instruction)168 Status FusionInstructionMerger::FuseIntoAllUsers(HloInstruction* instruction) {
169   // Merge fused instructions from 'fusion' into each user.
170   std::vector<HloInstruction*> users = instruction->users();
171   for (HloInstruction* user : users) {
172     if (dump_fusion_visualization_) {
173       RegisterFusionState(
174           *computation_,
175           absl::StrCat("About to fuse |", instruction->name(), "| into |",
176                        user->name(), "| inside FusionMerger"),
177           /*consumer=*/*user,
178           /*producer=*/instruction);
179     }
180 
181     // Wrap consumers which are not fusions first.
182     HloInstruction* consumer = user;
183     if (consumer->opcode() != HloOpcode::kFusion) {
184       consumer = computation_->AddInstruction(HloInstruction::CreateFusion(
185           user->shape(), ChooseFusionKind(*instruction, *user), user));
186       TF_CHECK_OK(computation_->ReplaceInstruction(user, consumer));
187     }
188     consumer->MergeFusionInstruction(instruction);
189     if (dump_fusion_visualization_) {
190       RegisterFusionState(
191           *computation_,
192           absl::StrCat("Fused |", instruction->name(), "| into |", user->name(),
193                        "| inside FusionMerger"),
194           *consumer);
195     }
196     changed_ = true;
197   }
198 
199   CHECK_EQ(0, instruction->user_count()) << instruction->ToString();
200   TF_RETURN_IF_ERROR(computation_->RemoveInstruction(instruction));
201   VLOG(2) << "Merged fusion instruction: " << instruction->name()
202           << " into users { "
203           << absl::StrJoin(users, ", ",
204                            [](std::string* out, HloInstruction* user) {
205                              absl::StrAppend(out, user->name());
206                            })
207           << " }";
208   return OkStatus();
209 }
210 
Run()211 Status FusionInstructionMerger::Run() {
212   for (HloInstruction* instruction : computation_->MakeInstructionPostOrder()) {
213     if (instruction->opcode() == HloOpcode::kFusion) {
214       FusionDecision was_fused = HandleFusion(instruction);
215       if (!was_fused) {
216         VLOG(2) << "Not fusing fusion |" << instruction->name()
217                 << "| with all of it's users due to: " << was_fused.Explain();
218         if (dump_fusion_visualization_ && !instruction->users().empty()) {
219           RegisterFusionState(
220               *computation_,
221               absl::StrCat(
222                   "Not fusing fusion |", instruction->name(),
223                   "| into all of its users due to: ", was_fused.Explain()),
224               // Just pick any consumer, since we are trying to merge into all.
225               /*consumer=*/*instruction->users()[0],
226               /*producer=*/instruction);
227         }
228       } else {
229         TF_RETURN_IF_ERROR(FuseIntoAllUsers(instruction));
230       }
231     }
232   }
233 
234   VLOG(1) << "FusionInstructionMerger EXIT"
235           << " computation: " << computation_->name()
236           << " total_visited: " << total_visited_
237           << " total_merged: " << total_merged_ << " merge failures { "
238           << " no_users: " << num_fail_no_users_
239           << " not_loop_fusion: " << num_fail_not_loop_fusion_
240           << " merge_all_users: " << num_fail_merge_all_users_
241           << " expensive_instruction: " << num_fail_expensive_fused_instruction_
242           << " net_bytes_transferred: " << num_fail_net_bytes_transferred_ratio_
243           << " inefficient_fusion_emitter: "
244           << num_fail_inefficient_fusion_emitter_
245           << " fusion_too_large: " << num_fail_fusion_too_large_ << " }";
246   return OkStatus();
247 }
248 
HandleFusion(HloInstruction * fusion)249 FusionDecision FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
250   ++total_visited_;
251 
252   // Skip 'fusion' instruction if there are no users into which we can merge.
253   if (fusion->users().empty()) {
254     ++num_fail_no_users_;
255     return "fusion has no users";
256   }
257 
258   // Skip 'fusion' instruction if it is not a loop fusion. Library fusion
259   // instructions match specific patterns, so they shouldn't be further fused.
260   // Input fusion instructions need to be rooted at a particular HLO (e.g.
261   // kReduce), so they shouldn't be further fused either.
262   if (!fusion->IsLoopFusion()) {
263     ++num_fail_not_loop_fusion_;
264     return "not a loop fusion";
265   }
266 
267   for (const HloInstruction* user : fusion->users()) {
268     FusionDecision fusible = IsProducerConsumerFusible(*fusion, *user)
269                                  .And({user->opcode() != HloOpcode::kBitcast,
270                                        "not fusing bitcast ops"});
271     if (!fusible) {
272       ++num_fail_merge_all_users_;
273       return fusible;
274     }
275   }
276 
277   // Skip 'fusion' instruction if merging it into all users would result in a
278   // net increase in bytes transferred (currently allowing the net bytes
279   // transferred to be exceeded up to ~10% in exchange for eliminating the
280   // overhead from a GPU kernel launch).
281   const double current_bytes_transferred = GetCurrentBytesTransferred(fusion);
282   const double merged_bytes_transferred = GetMergedBytesTransferred(fusion);
283   const double merged_to_current_bytes_ratio =
284       merged_bytes_transferred / std::max(1.0, current_bytes_transferred);
285   if (merged_to_current_bytes_ratio > 1.10) {
286     ++num_fail_net_bytes_transferred_ratio_;
287     return FusionDecision{} << "merged-to-current-bytes-ratio of "
288                             << merged_to_current_bytes_ratio
289                             << " is not favorable";
290   }
291 
292   // Skip 'fusion' instruction if any of its fused instructions are expensive.
293   // This is done to avoid the duplication of expensive instructions, which
294   // would occur if 'fusion' were merged into multiple users.
295   //
296   // Also, we don't want to fuse expensive instructions with instructions which
297   // reuse its operand values (e.g. Broadcast instructions).
298   //
299   // However, if we are going to save a "lot" in memory bandwidth then we
300   // ignore how expensive the fusion instructions are.  The heuristic used to
301   // determine "a lot" is the following: merging must reduce memory traffic by a
302   // factor of 0.3, and the amount of memory accessed must not be entirely
303   // trivial (above 1K).  This likely has room for improvement in the future.
304 
305   bool allow_expensive_ops =
306       (fusion->user_count() == 1 || (merged_to_current_bytes_ratio < 0.3 &&
307                                      current_bytes_transferred > 1024)) &&
308       !absl::c_any_of(fusion->users(), [fusion](const HloInstruction* user) {
309         int64_t operand_index = user->operand_index(fusion);
310         return user->ReusesOperandElements(operand_index);
311       });
312   if (!allow_expensive_ops) {
313     for (const HloInstruction* instruction : fusion->fused_instructions()) {
314       if (instruction->opcode() != HloOpcode::kParameter &&
315           GpuInstructionFusion::IsExpensive(*instruction)) {
316         ++num_fail_expensive_fused_instruction_;
317         return FusionDecision{} << "fusion contains an expensive instruction |"
318                                 << instruction->name() << "|";
319       }
320     }
321   }
322 
323   // Skip 'fusion' instruction if merging it into at least one of the users
324   // would cause too much code duplication because of inefficiencies in the
325   // fusion emitter.
326   // TODO(b/119692968): Remove this once the fusion emitter can handle arbitrary
327   // fusion nodes.
328   for (const HloInstruction* user : fusion->users()) {
329     if (FusedIrEmitter::IsFusedIrEmitterInefficient(/*consumer=*/*user,
330                                                     /*producer=*/*fusion)) {
331       ++num_fail_inefficient_fusion_emitter_;
332       return FusionDecision{}
333              << "fusion contains user |" << user->ToShortString()
334              << "| which would cause inefficiency in fusion emitter";
335     }
336 
337     // Skip 'fusion' instruction if merging it into at least one of the users
338     // would make the fusion too big.
339     FusionDecision fits = FusionFitsInBudget(*fusion, *user);
340     if (!fits) {
341       ++num_fail_fusion_too_large_;
342       return fits;
343     }
344   }
345 
346   ++total_merged_;
347   return {};
348 }
349 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)350 StatusOr<bool> FusionMerger::Run(
351     HloModule* module,
352     const absl::flat_hash_set<absl::string_view>& execution_threads) {
353   bool changed = false;
354   VLOG(2) << "FusionMerger for module: " << module->name();
355   for (auto* computation :
356        module->MakeNonfusionComputations(execution_threads)) {
357     VLOG(1) << "Before running FusionInstructionMerger for computation: "
358             << computation->name();
359     XLA_VLOG_LINES(3, computation->ToString());
360 
361     FusionInstructionMerger fusion_merger(computation);
362     TF_RETURN_IF_ERROR(fusion_merger.Run());
363     changed |= fusion_merger.changed();
364 
365     VLOG(1) << "After running FusionInstructionMerger for computation: "
366             << computation->name() << " changed: " << changed;
367     XLA_VLOG_LINES(3, computation->ToString());
368   }
369   return changed;
370 }
371 
372 }  // namespace gpu
373 }  // namespace xla
374