/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/tools/hlo_control_flow_flattening.h"

#include <algorithm>
#include <functional>
#include <string>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"

namespace xla {
namespace {

// Create a constant (recursively for tuples) of the given shape and add it to
// the computation.
HloInstruction* CreateConstant(const Shape& shape,
                               HloComputation* computation) {
  if (shape.IsTuple()) {
    std::vector<HloInstruction*> tuple_arguments(shape.tuple_shapes_size());
    for (int index = 0; index < shape.tuple_shapes_size(); ++index) {
      tuple_arguments[index] =
          CreateConstant(shape.tuple_shapes(index), computation);
    }
    return computation->AddInstruction(
        HloInstruction::CreateTuple(tuple_arguments));
  } else {
    return computation->AddInstruction(
        HloInstruction::CreateConstant(Literal::CreateFromShape(shape)));
  }
}

// Prints sub-expression rooted at inst for a given depth.
void PrintSubexpression(HloInstruction* inst, int depth) {
  if (depth == 0) {
    return;
  }
  for (auto* operand : inst->operands()) {
    PrintSubexpression(operand, depth - 1);
  }
  VLOG(2) << inst->ToString();
}

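// Returns true if `inst` is a constant whose shape is an (effective) scalar
// of integer type.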
bool IsConstantScalarInt(const HloInstruction* inst) {
  return inst->opcode() == HloOpcode::kConstant &&
         ShapeUtil::IsEffectiveScalar(inst->shape()) &&
         inst->shape().IsInteger();
}

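// Returns true if `while_hlo` is not nested inside another while loop, i.e.
// walking the call graph up from its parent computation to the entry
// computation encounters no other while instruction. Requires a flattened
// module in which each computation has exactly one call site.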
bool IsNotContainedInLoop(const HloInstruction& while_hlo,
                          const CallGraph& call_graph) {
  const HloComputation* computation = while_hlo.parent();
  while (!computation->IsEntryComputation()) {
    auto& node = call_graph.GetNode(computation);
    CHECK_EQ(node.caller_callsites().size(), 1)
        << "The module is not flattened!";
    auto& callsite = node.caller_callsites()[0];
    if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
      // Another while loop has been found traversing up the call tree.
      return false;
    }
    computation = callsite.instruction()->parent();
  }
  // No calling while loops were found.
  return true;
}

}  // namespace

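// Derives a loop bound for `while_hlo` by pattern-matching its condition
// against a comparison with a constant scalar integer operand, e.g. a
// condition root of roughly `compare(iv, constant(42)), direction=LT` yields
// a bound of 42. Positive bounds are capped at `max_loop_count`; if no bound
// can be derived, `default_loop_count` is returned.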
int GetLoopBound(const HloInstruction& while_hlo, const int default_loop_count,
                 const int max_loop_count) {
  HloInstruction* condition = while_hlo.while_condition()->root_instruction();
  if (condition->opcode() == HloOpcode::kCompare) {
    int64_t value = 0;
    Comparison::Direction cmp = condition->comparison_direction();
    if ((cmp == Comparison::Direction::kLt ||
         cmp == Comparison::Direction::kLe ||
         cmp == Comparison::Direction::kNe) &&
        IsConstantScalarInt(condition->operand(1))) {
      value = *condition->operand(1)->literal().GetFirstInteger();
    } else if ((cmp == Comparison::Direction::kGt ||
                cmp == Comparison::Direction::kGe ||
                cmp == Comparison::Direction::kNe) &&
               IsConstantScalarInt(condition->operand(0))) {
      value = *condition->operand(0)->literal().GetFirstInteger();
    }
    if (value > 0) {
      // Caps to a max loop count to avoid long execution times.
      return std::min(value, static_cast<int64_t>(max_loop_count));
    }
  }
  return default_loop_count;
}

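// Like GetLoopBound, but additionally caps the bound at `max_outer_loop_count`
// when `while_hlo` is an outermost loop, i.e. not itself contained in another
// while loop.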
int GetLoopBoundWithOuterLoopMax(const HloInstruction& while_hlo,
                                 const CallGraph& call_graph,
                                 const int default_loop_count,
                                 const int max_outer_loop_count,
                                 const int max_loop_count) {
  int loop_bound = GetLoopBound(while_hlo, default_loop_count, max_loop_count);
  if (loop_bound > max_outer_loop_count) {
    // Do the inexpensive loop-bound check first so that the expensive graph
    // traversal in IsNotContainedInLoop runs as rarely as possible.
    if (IsNotContainedInLoop(while_hlo, call_graph)) {
      return max_outer_loop_count;
    }
  }
  return loop_bound;
}

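// Rewrites `while_hlo` so that it runs for a fixed number of iterations: a new
// S32 induction variable is appended to the loop state, the loop body
// increments it by one, and the original loop condition is replaced with a
// comparison against the derived loop bound, roughly:
//
//   before: cond { ... ROOT original_predicate }
//   after:  cond { iv = get-tuple-element(param), index=N
//                  ROOT lt = compare(iv, constant(loop_bound)), direction=LT }
//
// Users of the while (and of the condition/body parameters) continue to see
// the original tuple shape through an extracted prefix.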
Status HloControlFlowFlattening::FlattenWhileLoop(
    HloInstruction* while_hlo, const CallGraph& call_graph) const {
  CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);
  HloComputation* computation = while_hlo->parent();
  // Add a new induction variable.
  HloInstruction* initialization = computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
  // Create a new while operand with the induction variable added.
  HloInstruction* old_tuple = while_hlo->mutable_operand(0);
  HloInstruction* new_tuple =
      TupleUtil::AppendSuffix(old_tuple, {initialization});
  int new_tuple_size = new_tuple->shape().tuple_shapes().size();
  TF_RETURN_IF_ERROR(while_hlo->ReplaceOperandWithDifferentShape(0, new_tuple));

  auto change_op_shape = [&](HloInstruction* instruction) {
    Shape* shape = instruction->mutable_shape();
    CHECK(shape->IsTuple());
    CHECK_EQ(shape->tuple_shapes().size(), new_tuple_size - 1);
    Shape* subshape = shape->add_tuple_shapes();
    return ShapeUtil::PopulateShape(S32, {}, subshape);
  };

  // Replace the given tuple-shaped instruction of size N in each of its
  // non-get-tuple-element users with a new tuple instruction which has the
  // first N - 1 elements.
  auto replace_non_gte_users =
      [](HloInstruction* new_tuple) -> StatusOr<HloInstruction*> {
    CHECK(new_tuple->shape().IsTuple());
    HloInstruction* prefix = nullptr;
    std::vector<HloInstruction*> users(new_tuple->users());
    for (HloInstruction* user : users) {
      if (user->opcode() == HloOpcode::kGetTupleElement) {
        continue;
      }
      // Lazily extract the prefix on demand and reuse it as needed.
      if (prefix == nullptr) {
        prefix = TupleUtil::ExtractPrefix(
            new_tuple, new_tuple->shape().tuple_shapes_size() - 1);
      }
      TF_RETURN_IF_ERROR(new_tuple->ReplaceUseWithDifferentShape(user, prefix));
    }
    return prefix;
  };

  {
    // Add the new variable to the while loop condition.
    HloComputation* condition = while_hlo->while_condition();
    TF_RETURN_IF_ERROR(change_op_shape(condition->parameter_instruction(0)));
    TF_RETURN_IF_ERROR(
        replace_non_gte_users(condition->parameter_instruction(0)).status());
    if (VLOG_IS_ON(2)) {
      VLOG(2) << "Loop condition in " << while_hlo->parent()->name();
      PrintSubexpression(condition->root_instruction(), /*depth=*/3);
    }
    const int loop_bound = GetLoopBoundWithOuterLoopMax(
        *while_hlo, call_graph, while_execution_count_, max_outer_loop_count_,
        max_loop_count_);

    VLOG(1) << "loop_bound = " << loop_bound;

    HloInstruction* limit = condition->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(loop_bound)));
    Shape shape = initialization->shape();
    HloInstruction* induction_variable =
        condition->AddInstruction(HloInstruction::CreateGetTupleElement(
            shape, condition->parameter_instruction(0), new_tuple_size - 1));
    HloInstruction* compare =
        condition->AddInstruction(HloInstruction::CreateCompare(
            ShapeUtil::MakeShape(PRED, {}), induction_variable, limit,
            ComparisonDirection::kLt));
    TF_RETURN_IF_ERROR(
        condition->ReplaceInstruction(condition->root_instruction(), compare));
  }

  {
    // Add the new variable to the while loop body.
    HloComputation* body = while_hlo->while_body();
    TF_RETURN_IF_ERROR(change_op_shape(body->parameter_instruction(0)));
    TF_RETURN_IF_ERROR(
        replace_non_gte_users(body->parameter_instruction(0)).status());
    HloInstruction* old_root = body->root_instruction();
    Shape shape = initialization->shape();
    HloInstruction* induction_variable =
        body->AddInstruction(HloInstruction::CreateGetTupleElement(
            shape, body->parameter_instruction(0), new_tuple_size - 1));
    HloInstruction* increment = body->AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
    induction_variable = body->AddInstruction(HloInstruction::CreateBinary(
        shape, HloOpcode::kAdd, induction_variable, increment));
    HloInstruction* new_root =
        TupleUtil::AppendSuffix(old_root, {induction_variable});
    body->set_root_instruction(new_root, /*accept_different_shape=*/true);
  }

  // Snapshot the users of while hlo before we add new users.
  std::vector<HloInstruction*> while_users(while_hlo->users().begin(),
                                           while_hlo->users().end());

  // Take care of the users of this while loop.
  TF_RETURN_IF_ERROR(change_op_shape(while_hlo));
  TF_ASSIGN_OR_RETURN(HloInstruction * prefix,
                      replace_non_gte_users(while_hlo));

  // If the while loop had been the root of its computation, make the prefix
  // the new root.
  if (while_hlo->parent()->root_instruction() == while_hlo) {
    // We need to set accept_different_shape=true to reset the root shape to
    // the original, because we have already changed the shape of the old root
    // (while).
    if (prefix == nullptr) {
      prefix = TupleUtil::ExtractPrefix(while_hlo, new_tuple_size - 1);
    }
    while_hlo->parent()->set_root_instruction(prefix,
                                               /*accept_different_shape=*/true);
  }

  return OkStatus();
}

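// Replaces an infeed with a tuple of a no-op custom call (producing the
// infeed's data shape) and the token that was the infeed's operand.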
Status HloControlFlowFlattening::RemoveInfeed(
    HloInstruction* infeed_hlo) const {
  CHECK_EQ(infeed_hlo->opcode(), HloOpcode::kInfeed);
  HloComputation* computation = infeed_hlo->parent();
  CHECK_EQ(infeed_hlo->shape().tuple_shapes_size(), 2);
  const Shape& infeed_shape = ShapeUtil::GetSubshape(infeed_hlo->shape(), {0});

  HloInstruction* custom_call = computation->AddInstruction(
      HloInstruction::CreateCustomCall(infeed_shape, {}, kNopCustomCallTarget));

  // Create a new tuple consisting of the custom call result and the token that
  // was originally the operand of infeed, and replace the infeed operation.
  auto new_tuple = HloInstruction::CreateTuple(
      {custom_call, infeed_hlo->mutable_operand(0)});
  TF_RETURN_IF_ERROR(
      computation->ReplaceWithNewInstruction(infeed_hlo, std::move(new_tuple)));

  return OkStatus();
}

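// Replaces a recv-done, and the recv it consumes, with a tuple of a no-op
// custom call (producing the received data shape) and the recv's token
// operand. The removed recv is recorded in `additional_removed`.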
Status HloControlFlowFlattening::RemoveRecvDone(
    HloInstruction* recv_done,
    absl::flat_hash_set<HloInstruction*>* additional_removed) const {
  CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
  CHECK_EQ(recv_done->operand_count(), 1);
  HloInstruction* recv = recv_done->mutable_operand(0);
  CHECK_EQ(recv->opcode(), HloOpcode::kRecv);

  HloComputation* computation = recv_done->parent();
  CHECK_EQ(recv_done->shape().tuple_shapes_size(), 2);
  const Shape& recv_shape = ShapeUtil::GetSubshape(recv_done->shape(), {0});

  HloInstruction* custom_call = computation->AddInstruction(
      HloInstruction::CreateCustomCall(recv_shape, {}, kNopCustomCallTarget));

  // Create a new tuple consisting of the custom call result and the token that
  // was originally the operand of recv, and replace the recv operation.
  auto new_tuple =
      HloInstruction::CreateTuple({custom_call, recv->mutable_operand(0)});
  TF_RETURN_IF_ERROR(
      computation->ReplaceWithNewInstruction(recv_done, std::move(new_tuple)));
  additional_removed->insert(recv);
  TF_RETURN_IF_ERROR(computation->RemoveInstruction(recv));

  return OkStatus();
}

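// Replaces an outfeed with a side-effecting "NopReturnToken" custom call over
// the same operands, so the operands stay live but no data is transferred.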
Status HloControlFlowFlattening::RemoveOutfeed(
    HloInstruction* outfeed_hlo) const {
  CHECK_EQ(outfeed_hlo->opcode(), HloOpcode::kOutfeed);
  HloComputation* computation = outfeed_hlo->parent();
  // Replace the outfeed with a no-op custom call with side effect to ensure
  // the operands aren't DCE'd.
  HloInstruction* custom_call =
      computation->AddInstruction(HloInstruction::CreateCustomCall(
          outfeed_hlo->shape(), outfeed_hlo->operands(), "NopReturnToken"));
  Cast<HloCustomCallInstruction>(custom_call)
      ->set_custom_call_has_side_effect(true);
  TF_RETURN_IF_ERROR(computation->ReplaceInstruction(outfeed_hlo, custom_call));

  return OkStatus();
}

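// Replaces a send-done, and the send it consumes, with a side-effecting
// "NopReturnToken" custom call over the send's operands. The removed send is
// recorded in `additional_removed`.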
Status HloControlFlowFlattening::RemoveSendDone(
    HloInstruction* send_done,
    absl::flat_hash_set<HloInstruction*>* additional_removed) const {
  CHECK_EQ(send_done->opcode(), HloOpcode::kSendDone);
  CHECK_EQ(send_done->operand_count(), 1);
  HloInstruction* send = send_done->mutable_operand(0);
  CHECK_EQ(send->opcode(), HloOpcode::kSend);

  HloComputation* computation = send_done->parent();
  HloInstruction* custom_call =
      computation->AddInstruction(HloInstruction::CreateCustomCall(
          send_done->shape(), send_done->operand(0)->operands(),
          "NopReturnToken"));
  Cast<HloCustomCallInstruction>(custom_call)
      ->set_custom_call_has_side_effect(true);

  TF_RETURN_IF_ERROR(computation->ReplaceInstruction(send_done, custom_call));
  additional_removed->insert(send);
  TF_RETURN_IF_ERROR(computation->RemoveInstruction(send));

  return OkStatus();
}

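// Replaces a collective op with a no-op custom call that keeps the original
// shape, operands, and backend config, and records the replaced op's
// canonical string in the custom call metadata.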
Status HloControlFlowFlattening::RemoveCollective(HloInstruction* hlo) const {
  HloComputation* computation = hlo->parent();
  HloInstruction* custom_call =
      computation->AddInstruction(HloInstruction::CreateCustomCall(
          hlo->shape(), hlo->operands(), kNopCustomCallTarget));
  // Copy backend config. This is necessary for a collective op in megacore
  // fusion.
  custom_call->CopyBackendConfigFrom(hlo);
  auto replaced_collective_op_str =
      hlo->ToString(HloPrintOptions().Canonical());
  TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, custom_call));
  custom_call->set_metadata_replaced_op(replaced_collective_op_str);
  return OkStatus();
}

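// Replaces a partition-id or replica-id instruction with a zero constant of
// the same shape.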
Status HloControlFlowFlattening::RemovePartitionOrReplicaId(
    HloInstruction* hlo) const {
  HloComputation* computation = hlo->parent();
  HloInstruction* zero = CreateConstant(hlo->shape(), computation);
  TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, zero));
  return OkStatus();
}

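// Applies the configured rewrites (while-loop flattening, infeed/outfeed
// removal, and communication-op removal) to every instruction in the module,
// then runs DCE and, if the module is scheduled, updates the schedule.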
StatusOr<bool> HloControlFlowFlattening::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  auto call_graph = CallGraph::Build(module);
  bool changed = false;
  absl::flat_hash_set<HloInstruction*> removed;
  for (HloComputation* computation : module->computations(execution_threads)) {
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      if (removed.contains(instruction)) {
        // Skip the instruction if it is already removed.
        continue;
      }
      if (flatten_while_loop_ && instruction->opcode() == HloOpcode::kWhile) {
        VLOG(1) << "Remove " << instruction->name();
        TF_RETURN_IF_ERROR(FlattenWhileLoop(instruction, *call_graph));
        changed = true;
      } else if (remove_infeed_outfeed_ &&
                 instruction->opcode() == HloOpcode::kInfeed) {
        VLOG(1) << "Remove " << instruction->name();
        TF_RETURN_IF_ERROR(RemoveInfeed(instruction));
        changed = true;
      } else if (remove_infeed_outfeed_ &&
                 instruction->opcode() == HloOpcode::kOutfeed) {
        VLOG(1) << "Remove " << instruction->name();
        TF_RETURN_IF_ERROR(RemoveOutfeed(instruction));
        changed = true;
      } else if (instruction->opcode() == HloOpcode::kSendDone) {
        auto send_done_instruction =
            DynCast<HloSendDoneInstruction>(instruction);
        CHECK(send_done_instruction);
        if (remove_comm_ || (remove_host_transfer_ &&
                             send_done_instruction->is_host_transfer())) {
          VLOG(1) << "Remove " << instruction->name();
          TF_RETURN_IF_ERROR(RemoveSendDone(instruction, &removed));
          changed = true;
        }
      } else if (instruction->opcode() == HloOpcode::kRecvDone) {
        auto recv_done_instruction =
            DynCast<HloRecvDoneInstruction>(instruction);
        CHECK(recv_done_instruction);
        if (remove_comm_ || (remove_host_transfer_ &&
                             recv_done_instruction->is_host_transfer())) {
          VLOG(1) << "Remove " << instruction->name();
          TF_RETURN_IF_ERROR(RemoveRecvDone(instruction, &removed));
          changed = true;
        }
      } else if (remove_comm_ && IsCollective(instruction) &&
                 !instruction->parent()->IsFusionComputation()) {
        VLOG(1) << "Remove " << instruction->name();
        TF_RETURN_IF_ERROR(RemoveCollective(instruction));
        changed = true;
      } else if (remove_comm_ &&
                 (instruction->opcode() == HloOpcode::kPartitionId ||
                  instruction->opcode() == HloOpcode::kReplicaId)) {
        VLOG(1) << "Remove " << instruction->name();
        TF_RETURN_IF_ERROR(RemovePartitionOrReplicaId(instruction));
      }
    }
  }

  HloDCE hlo_dce;
  TF_ASSIGN_OR_RETURN(bool dce_changed, hlo_dce.Run(module, execution_threads));
  changed |= dce_changed;

  // Fix the schedule if the module was scheduled.
  if (changed && module->has_schedule()) {
    TF_RETURN_IF_ERROR(module->schedule().Update());
  }
  XLA_VLOG_LINES(3, module->ToString());
  return changed;
}

}  // namespace xla