1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h"
17
18 #include <algorithm>
19 #include <string>
20 #include <vector>
21
22 #include "absl/algorithm/container.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
25 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
26 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
27 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
28 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/util.h"
31 #include "tensorflow/core/lib/core/errors.h"
32
33 namespace xla {
34 namespace gpu {
35
36 namespace {
37
38 // Traverses users of tuple shape, adding leaf instructions to 'instructions'.
MaybeResolveTupleElements(HloInstruction * instruction,std::vector<HloInstruction * > * instructions)39 void MaybeResolveTupleElements(HloInstruction* instruction,
40 std::vector<HloInstruction*>* instructions) {
41 if (instruction->shape().IsTuple()) {
42 for (auto tuple_user : instruction->users()) {
43 MaybeResolveTupleElements(tuple_user, instructions);
44 }
45 } else {
46 instructions->push_back(instruction);
47 }
48 }
49
50 // Returns the bytes read by fusion parameter 'param', by returning the byte
51 // size of 'param' shape (or the cumulative byte sizes of all leaf tuple
52 // elements if 'param' is tuple-shaped).
53 //
54 // In the special case where all users of 'param' (or all users of a leaf
55 // tuple element if 'param' is tuple-shaped) are Slice instructions, the size
56 // of each slice instruction is accumulated instead, to give a more accurate
57 // value for bytes read.
CalculateBytesReadByFusionParameter(HloInstruction * param)58 double CalculateBytesReadByFusionParameter(HloInstruction* param) {
59 CHECK_EQ(HloOpcode::kParameter, param->opcode());
60
61 // Adds all leaf tuple elements to 'instructions' if 'param' is tuple-shaped.
62 // Adds 'param' to 'instructions' otherwise.
63 std::vector<HloInstruction*> instructions;
64 MaybeResolveTupleElements(param, &instructions);
65
66 // Iterate through 'instructions' accumulating byte sizes of each instruction
67 // shape. For each 'instruction' in 'instructions', if all users of
68 // 'instruction' are Slice instructions, accumulates the byte sizes of each
69 // Slice for a more accurate estimate of bytes read.
70 double bytes = 0.0;
71 for (auto& instruction : instructions) {
72 if (absl::c_all_of(
73 instruction->users(), [](const HloInstruction* instruction) {
74 return instruction->opcode() == HloOpcode::kSlice ||
75 instruction->opcode() == HloOpcode::kDynamicSlice;
76 })) {
77 // All users are slice: accumulate bytes of all user slice instructions.
78 for (auto& user : instruction->users()) {
79 bytes += ShapeUtil::ByteSizeOf(user->shape());
80 }
81 } else {
82 // Some users are not slice: accumulate full size of 'instruction'.
83 bytes += ShapeUtil::ByteSizeOf(instruction->shape());
84 }
85 }
86 return bytes;
87 }
88
89 // Returns the bytes read by all fusion parameters of instruction 'fusion'.
CalculateBytesReadByFusionInstruction(HloInstruction * fusion)90 double CalculateBytesReadByFusionInstruction(HloInstruction* fusion) {
91 double bytes = 0.0;
92 for (auto* fused_instruction : fusion->fused_instructions()) {
93 if (fused_instruction->opcode() != HloOpcode::kParameter) {
94 continue;
95 }
96 bytes += CalculateBytesReadByFusionParameter(fused_instruction);
97 }
98 return bytes;
99 }
100
101 // Returns bytes transferred by instruction 'fusion', including the bytes
102 // that would be read by all users.
GetCurrentBytesTransferred(HloInstruction * fusion)103 double GetCurrentBytesTransferred(HloInstruction* fusion) {
104 CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
105 const double bytes_read = CalculateBytesReadByFusionInstruction(fusion);
106 double bytes_written = 0;
107 if (fusion->IsMultiOutputFusion()) {
108 for (auto& operand : fusion->fused_expression_root()->operands()) {
109 bytes_written += ShapeUtil::ByteSizeOf(operand->shape());
110 }
111 } else {
112 bytes_written =
113 ShapeUtil::ByteSizeOf(fusion->fused_expression_root()->shape());
114 }
115 // Current bytes transferred (ignoring non 'fusion' user operands) is bytes
116 // read and written by 'fusion', plus reads of size 'bytes_written' for each
117 // user.
118 return bytes_read + bytes_written * (fusion->user_count() + 1);
119 }
120
121 // Returns bytes transferred if 'fusion' were to be merged into its users.
GetMergedBytesTransferred(HloInstruction * fusion)122 double GetMergedBytesTransferred(HloInstruction* fusion) {
123 CHECK_EQ(HloOpcode::kFusion, fusion->opcode());
124 return CalculateBytesReadByFusionInstruction(fusion) * fusion->user_count();
125 }
126
127 } // anonymous namespace
128
129 // For each fusion F, attempts to fuse F into *all* of F's users (does not fuse
130 // if can't fuse into at least one).
131 class FusionInstructionMerger {
132 public:
FusionInstructionMerger(HloComputation * computation)133 explicit FusionInstructionMerger(HloComputation* computation)
134 : computation_(computation),
135 dump_fusion_visualization_(computation->parent()
136 ->config()
137 .debug_options()
138 .xla_dump_fusion_visualization()) {}
139
140 Status Run();
141
changed() const142 bool changed() const { return changed_; }
143
144 private:
145 FusionDecision HandleFusion(HloInstruction* fusion);
146 Status FuseIntoAllUsers(HloInstruction* instruction);
147
148 HloComputation* computation_;
149
150 bool changed_ = false;
151 bool dump_fusion_visualization_ = false;
152
153 // Fusion instruction merge stats.
154 int total_visited_ = 0;
155 int total_merged_ = 0;
156 int num_fail_no_users_ = 0;
157 int num_fail_not_loop_fusion_ = 0;
158 int num_fail_merge_all_users_ = 0;
159 int num_fail_expensive_fused_instruction_ = 0;
160 int num_fail_net_bytes_transferred_ratio_ = 0;
161 int num_fail_inefficient_fusion_emitter_ = 0;
162 int num_fail_fusion_too_large_ = 0;
163
164 FusionInstructionMerger(const FusionInstructionMerger&) = delete;
165 FusionInstructionMerger& operator=(const FusionInstructionMerger&) = delete;
166 };
167
FuseIntoAllUsers(HloInstruction * instruction)168 Status FusionInstructionMerger::FuseIntoAllUsers(HloInstruction* instruction) {
169 // Merge fused instructions from 'fusion' into each user.
170 std::vector<HloInstruction*> users = instruction->users();
171 for (HloInstruction* user : users) {
172 if (dump_fusion_visualization_) {
173 RegisterFusionState(
174 *computation_,
175 absl::StrCat("About to fuse |", instruction->name(), "| into |",
176 user->name(), "| inside FusionMerger"),
177 /*consumer=*/*user,
178 /*producer=*/instruction);
179 }
180
181 // Wrap consumers which are not fusions first.
182 HloInstruction* consumer = user;
183 if (consumer->opcode() != HloOpcode::kFusion) {
184 consumer = computation_->AddInstruction(HloInstruction::CreateFusion(
185 user->shape(), ChooseFusionKind(*instruction, *user), user));
186 TF_CHECK_OK(computation_->ReplaceInstruction(user, consumer));
187 }
188 consumer->MergeFusionInstruction(instruction);
189 if (dump_fusion_visualization_) {
190 RegisterFusionState(
191 *computation_,
192 absl::StrCat("Fused |", instruction->name(), "| into |", user->name(),
193 "| inside FusionMerger"),
194 *consumer);
195 }
196 changed_ = true;
197 }
198
199 CHECK_EQ(0, instruction->user_count()) << instruction->ToString();
200 TF_RETURN_IF_ERROR(computation_->RemoveInstruction(instruction));
201 VLOG(2) << "Merged fusion instruction: " << instruction->name()
202 << " into users { "
203 << absl::StrJoin(users, ", ",
204 [](std::string* out, HloInstruction* user) {
205 absl::StrAppend(out, user->name());
206 })
207 << " }";
208 return OkStatus();
209 }
210
Run()211 Status FusionInstructionMerger::Run() {
212 for (HloInstruction* instruction : computation_->MakeInstructionPostOrder()) {
213 if (instruction->opcode() == HloOpcode::kFusion) {
214 FusionDecision was_fused = HandleFusion(instruction);
215 if (!was_fused) {
216 VLOG(2) << "Not fusing fusion |" << instruction->name()
217 << "| with all of it's users due to: " << was_fused.Explain();
218 if (dump_fusion_visualization_ && !instruction->users().empty()) {
219 RegisterFusionState(
220 *computation_,
221 absl::StrCat(
222 "Not fusing fusion |", instruction->name(),
223 "| into all of its users due to: ", was_fused.Explain()),
224 // Just pick any consumer, since we are trying to merge into all.
225 /*consumer=*/*instruction->users()[0],
226 /*producer=*/instruction);
227 }
228 } else {
229 TF_RETURN_IF_ERROR(FuseIntoAllUsers(instruction));
230 }
231 }
232 }
233
234 VLOG(1) << "FusionInstructionMerger EXIT"
235 << " computation: " << computation_->name()
236 << " total_visited: " << total_visited_
237 << " total_merged: " << total_merged_ << " merge failures { "
238 << " no_users: " << num_fail_no_users_
239 << " not_loop_fusion: " << num_fail_not_loop_fusion_
240 << " merge_all_users: " << num_fail_merge_all_users_
241 << " expensive_instruction: " << num_fail_expensive_fused_instruction_
242 << " net_bytes_transferred: " << num_fail_net_bytes_transferred_ratio_
243 << " inefficient_fusion_emitter: "
244 << num_fail_inefficient_fusion_emitter_
245 << " fusion_too_large: " << num_fail_fusion_too_large_ << " }";
246 return OkStatus();
247 }
248
HandleFusion(HloInstruction * fusion)249 FusionDecision FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
250 ++total_visited_;
251
252 // Skip 'fusion' instruction if there are no users into which we can merge.
253 if (fusion->users().empty()) {
254 ++num_fail_no_users_;
255 return "fusion has no users";
256 }
257
258 // Skip 'fusion' instruction if it is not a loop fusion. Library fusion
259 // instructions match specific patterns, so they shouldn't be further fused.
260 // Input fusion instructions need to be rooted at a particular HLO (e.g.
261 // kReduce), so they shouldn't be further fused either.
262 if (!fusion->IsLoopFusion()) {
263 ++num_fail_not_loop_fusion_;
264 return "not a loop fusion";
265 }
266
267 for (const HloInstruction* user : fusion->users()) {
268 FusionDecision fusible = IsProducerConsumerFusible(*fusion, *user)
269 .And({user->opcode() != HloOpcode::kBitcast,
270 "not fusing bitcast ops"});
271 if (!fusible) {
272 ++num_fail_merge_all_users_;
273 return fusible;
274 }
275 }
276
277 // Skip 'fusion' instruction if merging it into all users would result in a
278 // net increase in bytes transferred (currently allowing the net bytes
279 // transferred to be exceeded up to ~10% in exchange for eliminating the
280 // overhead from a GPU kernel launch).
281 const double current_bytes_transferred = GetCurrentBytesTransferred(fusion);
282 const double merged_bytes_transferred = GetMergedBytesTransferred(fusion);
283 const double merged_to_current_bytes_ratio =
284 merged_bytes_transferred / std::max(1.0, current_bytes_transferred);
285 if (merged_to_current_bytes_ratio > 1.10) {
286 ++num_fail_net_bytes_transferred_ratio_;
287 return FusionDecision{} << "merged-to-current-bytes-ratio of "
288 << merged_to_current_bytes_ratio
289 << " is not favorable";
290 }
291
292 // Skip 'fusion' instruction if any of its fused instructions are expensive.
293 // This is done to avoid the duplication of expensive instructions, which
294 // would occur if 'fusion' were merged into multiple users.
295 //
296 // Also, we don't want to fuse expensive instructions with instructions which
297 // reuse its operand values (e.g. Broadcast instructions).
298 //
299 // However, if we are going to save a "lot" in memory bandwidth then we
300 // ignore how expensive the fusion instructions are. The heuristic used to
301 // determine "a lot" is the following: merging must reduce memory traffic by a
302 // factor of 0.3, and the amount of memory accessed must not be entirely
303 // trivial (above 1K). This likely has room for improvement in the future.
304
305 bool allow_expensive_ops =
306 (fusion->user_count() == 1 || (merged_to_current_bytes_ratio < 0.3 &&
307 current_bytes_transferred > 1024)) &&
308 !absl::c_any_of(fusion->users(), [fusion](const HloInstruction* user) {
309 int64_t operand_index = user->operand_index(fusion);
310 return user->ReusesOperandElements(operand_index);
311 });
312 if (!allow_expensive_ops) {
313 for (const HloInstruction* instruction : fusion->fused_instructions()) {
314 if (instruction->opcode() != HloOpcode::kParameter &&
315 GpuInstructionFusion::IsExpensive(*instruction)) {
316 ++num_fail_expensive_fused_instruction_;
317 return FusionDecision{} << "fusion contains an expensive instruction |"
318 << instruction->name() << "|";
319 }
320 }
321 }
322
323 // Skip 'fusion' instruction if merging it into at least one of the users
324 // would cause too much code duplication because of inefficiencies in the
325 // fusion emitter.
326 // TODO(b/119692968): Remove this once the fusion emitter can handle arbitrary
327 // fusion nodes.
328 for (const HloInstruction* user : fusion->users()) {
329 if (FusedIrEmitter::IsFusedIrEmitterInefficient(/*consumer=*/*user,
330 /*producer=*/*fusion)) {
331 ++num_fail_inefficient_fusion_emitter_;
332 return FusionDecision{}
333 << "fusion contains user |" << user->ToShortString()
334 << "| which would cause inefficiency in fusion emitter";
335 }
336
337 // Skip 'fusion' instruction if merging it into at least one of the users
338 // would make the fusion too big.
339 FusionDecision fits = FusionFitsInBudget(*fusion, *user);
340 if (!fits) {
341 ++num_fail_fusion_too_large_;
342 return fits;
343 }
344 }
345
346 ++total_merged_;
347 return {};
348 }
349
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)350 StatusOr<bool> FusionMerger::Run(
351 HloModule* module,
352 const absl::flat_hash_set<absl::string_view>& execution_threads) {
353 bool changed = false;
354 VLOG(2) << "FusionMerger for module: " << module->name();
355 for (auto* computation :
356 module->MakeNonfusionComputations(execution_threads)) {
357 VLOG(1) << "Before running FusionInstructionMerger for computation: "
358 << computation->name();
359 XLA_VLOG_LINES(3, computation->ToString());
360
361 FusionInstructionMerger fusion_merger(computation);
362 TF_RETURN_IF_ERROR(fusion_merger.Run());
363 changed |= fusion_merger.changed();
364
365 VLOG(1) << "After running FusionInstructionMerger for computation: "
366 << computation->name() << " changed: " << changed;
367 XLA_VLOG_LINES(3, computation->ToString());
368 }
369 return changed;
370 }
371
372 } // namespace gpu
373 } // namespace xla
374