xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tools/hlo_control_flow_flattening.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tools/hlo_control_flow_flattening.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <string>
21 
22 #include "absl/container/flat_hash_set.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_dce.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
31 #include "tensorflow/compiler/xla/service/tuple_util.h"
32 
33 namespace xla {
34 namespace {
35 
36 // Create a constant (recursively for tuples) of the given shape and add it to
37 // the computation.
CreateConstant(const Shape & shape,HloComputation * computation)38 HloInstruction* CreateConstant(const Shape& shape,
39                                HloComputation* computation) {
40   if (shape.IsTuple()) {
41     std::vector<HloInstruction*> tuple_arguments(shape.tuple_shapes_size());
42     for (int index = 0; index < shape.tuple_shapes_size(); ++index) {
43       tuple_arguments[index] =
44           CreateConstant(shape.tuple_shapes(index), computation);
45     }
46     return computation->AddInstruction(
47         HloInstruction::CreateTuple(tuple_arguments));
48   } else {
49     return computation->AddInstruction(
50         HloInstruction::CreateConstant(Literal::CreateFromShape(shape)));
51   }
52 }
53 
54 // Prints sub-expression rooted at inst for a given depth.
PrintSubexpression(HloInstruction * inst,int depth)55 void PrintSubexpression(HloInstruction* inst, int depth) {
56   if (depth == 0) {
57     return;
58   }
59   for (auto* operand : inst->operands()) {
60     PrintSubexpression(operand, depth - 1);
61   }
62   VLOG(2) << inst->ToString();
63 }
64 
IsConstantScalarInt(const HloInstruction * inst)65 bool IsConstantScalarInt(const HloInstruction* inst) {
66   return inst->opcode() == HloOpcode::kConstant &&
67          ShapeUtil::IsEffectiveScalar(inst->shape()) &&
68          inst->shape().IsInteger();
69 }
70 
IsNotContainedInLoop(const HloInstruction & while_hlo,const CallGraph & call_graph)71 bool IsNotContainedInLoop(const HloInstruction& while_hlo,
72                           const CallGraph& call_graph) {
73   const HloComputation* computation = while_hlo.parent();
74   while (!computation->IsEntryComputation()) {
75     auto& node = call_graph.GetNode(computation);
76     CHECK_EQ(node.caller_callsites().size(), 1)
77         << "The module is not flattened!";
78     auto& callsite = node.caller_callsites()[0];
79     if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
80       // Another while loop has been found traversing up the call tree.
81       return false;
82     }
83     computation = callsite.instruction()->parent();
84   }
85   // No calling while loops were found.
86   return true;
87 }
88 
89 }  // namespace
90 
GetLoopBound(const HloInstruction & while_hlo,const int default_loop_count,const int max_loop_count)91 int GetLoopBound(const HloInstruction& while_hlo, const int default_loop_count,
92                  const int max_loop_count) {
93   HloInstruction* condition = while_hlo.while_condition()->root_instruction();
94   if (condition->opcode() == HloOpcode::kCompare) {
95     int64_t value = 0;
96     Comparison::Direction cmp = condition->comparison_direction();
97     if ((cmp == Comparison::Direction::kLt ||
98          cmp == Comparison::Direction::kLe ||
99          cmp == Comparison::Direction::kNe) &&
100         IsConstantScalarInt(condition->operand(1))) {
101       value = *condition->operand(1)->literal().GetFirstInteger();
102     } else if ((cmp == Comparison::Direction::kGt ||
103                 cmp == Comparison::Direction::kGe ||
104                 cmp == Comparison::Direction::kNe) &&
105                IsConstantScalarInt(condition->operand(0))) {
106       value = *condition->operand(0)->literal().GetFirstInteger();
107     }
108     if (value > 0) {
109       // Caps to a max loop count to avoid long execution times.
110       return std::min(value, static_cast<int64_t>(max_loop_count));
111     }
112   }
113   return default_loop_count;
114 }
115 
GetLoopBoundWithOuterLoopMax(const HloInstruction & while_hlo,const CallGraph & call_graph,const int default_loop_count,const int max_outer_loop_count,const int max_loop_count)116 int GetLoopBoundWithOuterLoopMax(const HloInstruction& while_hlo,
117                                  const CallGraph& call_graph,
118                                  const int default_loop_count,
119                                  const int max_outer_loop_count,
120                                  const int max_loop_count) {
121   int loop_bound = GetLoopBound(while_hlo, default_loop_count, max_loop_count);
122   if (loop_bound > max_outer_loop_count) {
123     // First does the inexpensive loop bound check to avoid as many
124     // expensive graph traversals in IsNotContainedInLoop as possible.
125     if (IsNotContainedInLoop(while_hlo, call_graph)) {
126       return max_outer_loop_count;
127     }
128   }
129   return loop_bound;
130 }
131 
FlattenWhileLoop(HloInstruction * while_hlo,const CallGraph & call_graph) const132 Status HloControlFlowFlattening::FlattenWhileLoop(
133     HloInstruction* while_hlo, const CallGraph& call_graph) const {
134   CHECK_EQ(while_hlo->opcode(), HloOpcode::kWhile);
135   HloComputation* computation = while_hlo->parent();
136   // Add a new induction variable.
137   HloInstruction* initialization = computation->AddInstruction(
138       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
139   // Create a new while operand with the induction variable added.
140   HloInstruction* old_tuple = while_hlo->mutable_operand(0);
141   HloInstruction* new_tuple =
142       TupleUtil::AppendSuffix(old_tuple, {initialization});
143   int new_tuple_size = new_tuple->shape().tuple_shapes().size();
144   TF_RETURN_IF_ERROR(while_hlo->ReplaceOperandWithDifferentShape(0, new_tuple));
145 
146   auto change_op_shape = [&](HloInstruction* instruction) {
147     Shape* shape = instruction->mutable_shape();
148     CHECK(shape->IsTuple());
149     CHECK_EQ(shape->tuple_shapes().size(), new_tuple_size - 1);
150     Shape* subshape = shape->add_tuple_shapes();
151     return ShapeUtil::PopulateShape(S32, {}, subshape);
152   };
153 
154   // Replace the given tuple-shaped instruction of size N in each of its
155   // non-get-tuple-element users with a new tuple instruction which has the
156   // first N - 1 elements.
157   auto replace_non_gte_users =
158       [](HloInstruction* new_tuple) -> StatusOr<HloInstruction*> {
159     CHECK(new_tuple->shape().IsTuple());
160     HloInstruction* prefix = nullptr;
161     std::vector<HloInstruction*> users(new_tuple->users());
162     for (HloInstruction* user : users) {
163       if (user->opcode() == HloOpcode::kGetTupleElement) {
164         continue;
165       }
166       // Lazily extract the prefix on demand, reuse it as needed.
167       if (prefix == nullptr) {
168         prefix = TupleUtil::ExtractPrefix(
169             new_tuple, new_tuple->shape().tuple_shapes_size() - 1);
170       }
171       TF_RETURN_IF_ERROR(new_tuple->ReplaceUseWithDifferentShape(user, prefix));
172     }
173     return prefix;
174   };
175 
176   {
177     // Add the new variable to the while loop condition.
178     HloComputation* condition = while_hlo->while_condition();
179     TF_RETURN_IF_ERROR(change_op_shape(condition->parameter_instruction(0)));
180     TF_RETURN_IF_ERROR(
181         replace_non_gte_users(condition->parameter_instruction(0)).status());
182     if (VLOG_IS_ON(2)) {
183       VLOG(2) << "Loop condition in " << while_hlo->parent()->name();
184       PrintSubexpression(condition->root_instruction(), /*depth=*/3);
185     }
186     const int loop_bound = GetLoopBoundWithOuterLoopMax(
187         *while_hlo, call_graph, while_execution_count_, max_outer_loop_count_,
188         max_loop_count_);
189 
190     VLOG(1) << "loop_bound = " << loop_bound;
191 
192     HloInstruction* limit = condition->AddInstruction(
193         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(loop_bound)));
194     Shape shape = initialization->shape();
195     HloInstruction* induction_variable =
196         condition->AddInstruction(HloInstruction::CreateGetTupleElement(
197             shape, condition->parameter_instruction(0), new_tuple_size - 1));
198     HloInstruction* compare =
199         condition->AddInstruction(HloInstruction::CreateCompare(
200             ShapeUtil::MakeShape(PRED, {}), induction_variable, limit,
201             ComparisonDirection::kLt));
202     TF_RETURN_IF_ERROR(
203         condition->ReplaceInstruction(condition->root_instruction(), compare));
204   }
205 
206   {
207     // Add the new variable to the while loop body.
208     HloComputation* body = while_hlo->while_body();
209     TF_RETURN_IF_ERROR(change_op_shape(body->parameter_instruction(0)));
210     TF_RETURN_IF_ERROR(
211         replace_non_gte_users(body->parameter_instruction(0)).status());
212     HloInstruction* old_root = body->root_instruction();
213     Shape shape = initialization->shape();
214     HloInstruction* induction_variable =
215         body->AddInstruction(HloInstruction::CreateGetTupleElement(
216             shape, body->parameter_instruction(0), new_tuple_size - 1));
217     HloInstruction* increment = body->AddInstruction(
218         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
219     induction_variable = body->AddInstruction(HloInstruction::CreateBinary(
220         shape, HloOpcode::kAdd, induction_variable, increment));
221     HloInstruction* new_root =
222         TupleUtil::AppendSuffix(old_root, {induction_variable});
223     body->set_root_instruction(new_root, /*accept_different_shape=*/true);
224   }
225 
226   // Snapshot the users of while hlo before we add new users.
227   std::vector<HloInstruction*> while_users(while_hlo->users().begin(),
228                                            while_hlo->users().end());
229 
230   // Take care of the users of this while loop.
231   TF_RETURN_IF_ERROR(change_op_shape(while_hlo));
232   TF_ASSIGN_OR_RETURN(HloInstruction * prefix,
233                       replace_non_gte_users(while_hlo));
234 
235   // If the while loop had been the root of its computation, make the prefix new
236   // root.
237   if (while_hlo->parent()->root_instruction() == while_hlo) {
238     // We need to set accept_different_shape=true to reset the root shape to the
239     // original, because we have already changed the shape of the old root
240     // (while).
241     if (prefix == nullptr) {
242       prefix = TupleUtil::ExtractPrefix(while_hlo, new_tuple_size - 1);
243     }
244     while_hlo->parent()->set_root_instruction(prefix,
245                                               /*accept_different_shape=*/true);
246   }
247 
248   return OkStatus();
249 }
250 
RemoveInfeed(HloInstruction * infeed_hlo) const251 Status HloControlFlowFlattening::RemoveInfeed(
252     HloInstruction* infeed_hlo) const {
253   CHECK_EQ(infeed_hlo->opcode(), HloOpcode::kInfeed);
254   HloComputation* computation = infeed_hlo->parent();
255   CHECK_EQ(infeed_hlo->shape().tuple_shapes_size(), 2);
256   const Shape& infeed_shape = ShapeUtil::GetSubshape(infeed_hlo->shape(), {0});
257 
258   HloInstruction* custom_call = computation->AddInstruction(
259       HloInstruction::CreateCustomCall(infeed_shape, {}, kNopCustomCallTarget));
260 
261   // Create a new tuple consisting op the constant and the token that was
262   // originally the operand of infeed, and replace the infeed operation.
263   auto new_tuple = HloInstruction::CreateTuple(
264       {custom_call, infeed_hlo->mutable_operand(0)});
265   TF_RETURN_IF_ERROR(
266       computation->ReplaceWithNewInstruction(infeed_hlo, std::move(new_tuple)));
267 
268   return OkStatus();
269 }
270 
RemoveRecvDone(HloInstruction * recv_done,absl::flat_hash_set<HloInstruction * > * additional_removed) const271 Status HloControlFlowFlattening::RemoveRecvDone(
272     HloInstruction* recv_done,
273     absl::flat_hash_set<HloInstruction*>* additional_removed) const {
274   CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
275   CHECK_EQ(recv_done->operand_count(), 1);
276   HloInstruction* recv = recv_done->mutable_operand(0);
277   CHECK_EQ(recv->opcode(), HloOpcode::kRecv);
278 
279   HloComputation* computation = recv_done->parent();
280   CHECK_EQ(recv_done->shape().tuple_shapes_size(), 2);
281   const Shape& recv_shape = ShapeUtil::GetSubshape(recv_done->shape(), {0});
282 
283   HloInstruction* custom_call = computation->AddInstruction(
284       HloInstruction::CreateCustomCall(recv_shape, {}, kNopCustomCallTarget));
285 
286   // Create a new tuple consisting op the constant and the token that was
287   // originally the operand of recv, and replace the recv operation.
288   auto new_tuple =
289       HloInstruction::CreateTuple({custom_call, recv->mutable_operand(0)});
290   TF_RETURN_IF_ERROR(
291       computation->ReplaceWithNewInstruction(recv_done, std::move(new_tuple)));
292   additional_removed->insert(recv);
293   TF_RETURN_IF_ERROR(computation->RemoveInstruction(recv));
294 
295   return OkStatus();
296 }
297 
RemoveOutfeed(HloInstruction * outfeed_hlo) const298 Status HloControlFlowFlattening::RemoveOutfeed(
299     HloInstruction* outfeed_hlo) const {
300   CHECK_EQ(outfeed_hlo->opcode(), HloOpcode::kOutfeed);
301   HloComputation* computation = outfeed_hlo->parent();
302   // Replace the outfeed with a no-op custom call with side effect to ensure the
303   // operands aren't DCE'd.
304   HloInstruction* custom_call =
305       computation->AddInstruction(HloInstruction::CreateCustomCall(
306           outfeed_hlo->shape(), outfeed_hlo->operands(), "NopReturnToken"));
307   Cast<HloCustomCallInstruction>(custom_call)
308       ->set_custom_call_has_side_effect(true);
309   TF_RETURN_IF_ERROR(computation->ReplaceInstruction(outfeed_hlo, custom_call));
310 
311   return OkStatus();
312 }
313 
RemoveSendDone(HloInstruction * send_done,absl::flat_hash_set<HloInstruction * > * additional_removed) const314 Status HloControlFlowFlattening::RemoveSendDone(
315     HloInstruction* send_done,
316     absl::flat_hash_set<HloInstruction*>* additional_removed) const {
317   CHECK_EQ(send_done->opcode(), HloOpcode::kSendDone);
318   CHECK_EQ(send_done->operand_count(), 1);
319   HloInstruction* send = send_done->mutable_operand(0);
320   CHECK_EQ(send->opcode(), HloOpcode::kSend);
321 
322   HloComputation* computation = send_done->parent();
323   HloInstruction* custom_call =
324       computation->AddInstruction(HloInstruction::CreateCustomCall(
325           send_done->shape(), send_done->operand(0)->operands(),
326           "NopReturnToken"));
327   Cast<HloCustomCallInstruction>(custom_call)
328       ->set_custom_call_has_side_effect(true);
329 
330   TF_RETURN_IF_ERROR(computation->ReplaceInstruction(send_done, custom_call));
331   additional_removed->insert(send);
332   TF_RETURN_IF_ERROR(computation->RemoveInstruction(send));
333 
334   return OkStatus();
335 }
336 
RemoveCollective(HloInstruction * hlo) const337 Status HloControlFlowFlattening::RemoveCollective(HloInstruction* hlo) const {
338   HloComputation* computation = hlo->parent();
339   HloInstruction* custom_call =
340       computation->AddInstruction(HloInstruction::CreateCustomCall(
341           hlo->shape(), hlo->operands(), kNopCustomCallTarget));
342   // Copy backend config. This is necessary for a collective op in megacore
343   // fusion.
344   custom_call->CopyBackendConfigFrom(hlo);
345   auto replaced_collective_op_str =
346       hlo->ToString(HloPrintOptions().Canonical());
347   TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, custom_call));
348   custom_call->set_metadata_replaced_op(replaced_collective_op_str);
349   return OkStatus();
350 }
351 
RemovePartitionOrReplicaId(HloInstruction * hlo) const352 Status HloControlFlowFlattening::RemovePartitionOrReplicaId(
353     HloInstruction* hlo) const {
354   HloComputation* computation = hlo->parent();
355   HloInstruction* zero = CreateConstant(hlo->shape(), computation);
356   TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, zero));
357   return OkStatus();
358 }
359 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)360 StatusOr<bool> HloControlFlowFlattening::Run(
361     HloModule* module,
362     const absl::flat_hash_set<absl::string_view>& execution_threads) {
363   auto call_graph = CallGraph::Build(module);
364   bool changed = false;
365   absl::flat_hash_set<HloInstruction*> removed;
366   for (HloComputation* computation : module->computations(execution_threads)) {
367     for (HloInstruction* instruction :
368          computation->MakeInstructionPostOrder()) {
369       if (removed.contains(instruction)) {
370         // Skip the instruction if it is already removed.
371         continue;
372       }
373       if (flatten_while_loop_ && instruction->opcode() == HloOpcode::kWhile) {
374         VLOG(1) << "Remove " << instruction->name();
375         TF_RETURN_IF_ERROR(FlattenWhileLoop(instruction, *call_graph));
376         changed = true;
377       } else if (remove_infeed_outfeed_ &&
378                  instruction->opcode() == HloOpcode::kInfeed) {
379         VLOG(1) << "Remove " << instruction->name();
380         TF_RETURN_IF_ERROR(RemoveInfeed(instruction));
381         changed = true;
382       } else if (remove_infeed_outfeed_ &&
383                  instruction->opcode() == HloOpcode::kOutfeed) {
384         VLOG(1) << "Remove " << instruction->name();
385         TF_RETURN_IF_ERROR(RemoveOutfeed(instruction));
386         changed = true;
387       } else if (instruction->opcode() == HloOpcode::kSendDone) {
388         auto send_done_instruction =
389             DynCast<HloSendDoneInstruction>(instruction);
390         CHECK(send_done_instruction);
391         if (remove_comm_ || (remove_host_transfer_ &&
392                              send_done_instruction->is_host_transfer())) {
393           VLOG(1) << "Remove " << instruction->name();
394           TF_RETURN_IF_ERROR(RemoveSendDone(instruction, &removed));
395           changed = true;
396         }
397       } else if (instruction->opcode() == HloOpcode::kRecvDone) {
398         auto recv_done_instruction =
399             DynCast<HloRecvDoneInstruction>(instruction);
400         CHECK(recv_done_instruction);
401         if (remove_comm_ || (remove_host_transfer_ &&
402                              recv_done_instruction->is_host_transfer())) {
403           VLOG(1) << "Remove " << instruction->name();
404           TF_RETURN_IF_ERROR(RemoveRecvDone(instruction, &removed));
405           changed = true;
406         }
407       } else if (remove_comm_ && IsCollective(instruction) &&
408                  !instruction->parent()->IsFusionComputation()) {
409         VLOG(1) << "Remove " << instruction->name();
410         TF_RETURN_IF_ERROR(RemoveCollective(instruction));
411         changed = true;
412       } else if (remove_comm_ &&
413                  (instruction->opcode() == HloOpcode::kPartitionId ||
414                   instruction->opcode() == HloOpcode::kReplicaId)) {
415         VLOG(1) << "Remove " << instruction->name();
416         TF_RETURN_IF_ERROR(RemovePartitionOrReplicaId(instruction));
417       }
418     }
419   }
420 
421   HloDCE hlo_dce;
422   TF_ASSIGN_OR_RETURN(bool dce_changed, hlo_dce.Run(module, execution_threads));
423   changed |= dce_changed;
424 
425   // Fix the schedule if the module was scheduled.
426   if (changed && module->has_schedule()) {
427     TF_RETURN_IF_ERROR(module->schedule().Update());
428   }
429   XLA_VLOG_LINES(3, module->ToString());
430   return changed;
431 }
432 
433 }  // namespace xla
434