xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tools/hlo_control_flow_flattening_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tools/hlo_control_flow_flattening.h"
17 
18 #include "absl/strings/str_replace.h"
19 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
20 #include "tensorflow/compiler/xla/service/despecializer.h"
21 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
22 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 
26 namespace xla {
27 namespace {
28 
29 namespace op = xla::testing::opcode_matchers;
30 
31 using HloControlFlowFlatteningTest = HloTestBase;
32 
33 constexpr int kDefaultMaxLoopCount = 1000;
34 
// Flattens a while loop that is itself the entry root. Afterwards:
//  * the while's init tuple carries an extra third element (a constant) that
//    serves as an iteration counter,
//  * the condition compares that counter (tuple index 2) against a constant,
//  * the body increments the counter with an add, and
//  * the entry root is a tuple of the while's original two outputs.
TEST_F(HloControlFlowFlatteningTest, WhileRoot) {
  absl::string_view hlo_string = R"(
  HloModule While
  While.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  While.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(100)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  ENTRY While {
    constant.3 = s32[] constant(42)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=While.condition, body=While.body
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap via the assert macro so an error from Run() is reported as a test
  // failure rather than aborting the test binary (as ValueOrDie() would).
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());

  auto root = module->entry_computation()->root_instruction();
  auto while_op = module->entry_computation()->GetInstructionWithName("while");
  // while_op is dereferenced below; fail cleanly instead of crashing.
  ASSERT_NE(while_op, nullptr);
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(while_op, 0),
                              op::GetTupleElement(while_op, 1)));
  EXPECT_THAT(while_op,
              op::While(op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                                  op::Constant())));
  auto condition = while_op->while_condition();
  EXPECT_THAT(
      condition->root_instruction(),
      op::Compare(op::GetTupleElement(op::Parameter(0), 2), op::Constant()));

  auto body = while_op->while_body();
  EXPECT_THAT(body->root_instruction(),
              op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                        op::Add(op::GetTupleElement(op::Parameter(0), 2),
                                op::Constant())));
}
88 
// Like WhileRoot, but the original condition delegates to a called
// computation containing a side-effecting custom-call. After flattening, the
// condition root is still just the counter comparison (tuple index 2 vs. a
// constant), and the body increments that counter.
TEST_F(HloControlFlowFlatteningTest, WhileConditionCallComputation) {
  absl::string_view hlo_string = R"(
  HloModule While
  While.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  While.condition.called {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] custom-call(), custom_call_target="AllocateBuffer", custom_call_has_side_effect=true
    less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
    ROOT tuple.2 = (pred[]) tuple(less-than)
  }
  While.condition {
    loop_var.3 = (s32[], s32[3]{0}) parameter(0)
    call = (pred[]) call(loop_var.3), to_apply=While.condition.called
    ROOT get-tuple-element.4 = pred[] get-tuple-element(call), index=0
  }
  ENTRY While {
    constant.3 = s32[] constant(42)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=While.condition, body=While.body
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  XLA_VLOG_LINES(3, "Loaded HLO module: " + module->ToString());
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());

  auto root = module->entry_computation()->root_instruction();
  auto while_op = module->entry_computation()->GetInstructionWithName("while");
  // while_op is dereferenced below; fail cleanly instead of crashing.
  ASSERT_NE(while_op, nullptr);
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(while_op, 0),
                              op::GetTupleElement(while_op, 1)));
  EXPECT_THAT(while_op,
              op::While(op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                                  op::Constant())));
  auto condition = while_op->while_condition();
  EXPECT_THAT(
      condition->root_instruction(),
      op::Compare(op::GetTupleElement(op::Parameter(0), 2), op::Constant()));

  auto body = while_op->while_body();
  EXPECT_THAT(body->root_instruction(),
              op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                        op::Add(op::GetTupleElement(op::Parameter(0), 2),
                                op::Constant())));
}
149 
// Same program as WhileRoot, but the module is parsed with is_scheduled=true.
// Flattening must produce the same rewritten condition/body as the unscheduled
// case, and the module should still carry a schedule afterwards.
TEST_F(HloControlFlowFlatteningTest, WhileRootScheduled) {
  absl::string_view hlo_string = R"(
  HloModule While, is_scheduled=true
  While.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  While.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(100)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  ENTRY While {
    constant.3 = s32[] constant(42)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=While.condition, body=While.body
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  // The point of this variant: the pass must not silently drop the schedule.
  EXPECT_TRUE(module->has_schedule());

  auto root = module->entry_computation()->root_instruction();
  auto while_op = module->entry_computation()->GetInstructionWithName("while");
  // while_op is dereferenced below; fail cleanly instead of crashing.
  ASSERT_NE(while_op, nullptr);
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(while_op, 0),
                              op::GetTupleElement(while_op, 1)));
  EXPECT_THAT(while_op,
              op::While(op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                                  op::Constant())));
  auto condition = while_op->while_condition();
  EXPECT_THAT(
      condition->root_instruction(),
      op::Compare(op::GetTupleElement(op::Parameter(0), 2), op::Constant()));

  // Mirror WhileRoot: the body must increment the injected counter.
  auto body = while_op->while_body();
  EXPECT_THAT(body->root_instruction(),
              op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                        op::Add(op::GetTupleElement(op::Parameter(0), 2),
                                op::Constant())));
}
197 
// A while loop with a non-root user (a fusion). After flattening, the fusion
// consumes a tuple of the while's first two outputs, so it never sees the
// extra counter element the pass adds to the loop state.
TEST_F(HloControlFlowFlatteningTest, WhileUser) {
  absl::string_view hlo_string = R"(
  HloModule While
  While.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  While.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(100)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  FusedComputation {
    param = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.4 = s32[] get-tuple-element(param), index=0
    get-tuple-element.5 = s32[3]{0} get-tuple-element(param), index=1
    broadcast = s32[3]{0} broadcast(get-tuple-element.4), dimensions={}
    ROOT add = s32[3]{0} add(broadcast, get-tuple-element.5)
  }
  ENTRY While {
    constant.3 = s32[] constant(42)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    while = (s32[], s32[3]{0}) while(tuple.1), condition=While.condition, body=While.body
    ROOT fusion = s32[3]{0} fusion(while), kind=kLoop, calls=FusedComputation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());

  auto fusion = module->entry_computation()->root_instruction();
  auto while_op = module->entry_computation()->GetInstructionWithName("while");
  // A missing "while" would otherwise surface as a confusing matcher failure.
  ASSERT_NE(while_op, nullptr);
  EXPECT_THAT(fusion, op::Fusion(op::Tuple(op::GetTupleElement(while_op, 0),
                                           op::GetTupleElement(while_op, 1))));
}
246 
// An infeed is replaced by a tuple of a custom-call (standing in for the data)
// and the original after-all token.
TEST_F(HloControlFlowFlatteningTest, Infeed) {
  absl::string_view hlo_string = R"(
  HloModule Infeed
  ENTRY Infeed {
    after-all = token[] after-all()
    ROOT infeed = ((bf16[3]{0}, s32[12,5]{0,1}), token[]) infeed(after-all)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  auto tuple = module->entry_computation()->root_instruction();
  EXPECT_THAT(tuple, op::Tuple(op::CustomCall(), op::AfterAll()));
}
267 
// Like Infeed, but the infeed result uses a tiled layout ({0,1:T(8,128)}).
// The replacement tuple must have exactly the shape, including the layout,
// that the infeed had before the pass ran.
TEST_F(HloControlFlowFlatteningTest, InfeedPreserveLayout) {
  absl::string_view hlo_string = R"(
  HloModule Infeed
  ENTRY Infeed {
    after-all = token[] after-all()
    ROOT infeed = ((bf16[3]{0}, s32[12,5]{0,1:T(8,128)}), token[]) infeed(after-all)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Snapshot the pre-pass shape for the layout comparison below.
  Shape root_shape = module->entry_computation()->root_instruction()->shape();
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  auto tuple = module->entry_computation()->root_instruction();
  EXPECT_THAT(tuple, op::Tuple(op::CustomCall(), op::AfterAll()));
  EXPECT_EQ(tuple->shape(), root_shape);
}
290 
// An outfeed is replaced by a custom-call that takes the same operands (the
// data parameter and the after-all token).
TEST_F(HloControlFlowFlatteningTest, Outfeed) {
  absl::string_view hlo_string = R"(
  HloModule Outfeed
  ENTRY Outfeed {
    param = (bf16[3]{0}, s32[12,5]{0,1}) parameter(0)
    after-all = token[] after-all()
    ROOT outfeed = token[] outfeed(param, after-all)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  auto custom_call = module->entry_computation()->root_instruction();
  EXPECT_THAT(custom_call, op::CustomCall(op::Parameter(0), op::AfterAll()));
}
312 
// A variadic all-reduce is replaced by a custom-call over the same two
// operands.
TEST_F(HloControlFlowFlatteningTest, AllReduce) {
  absl::string_view hlo_string = R"(
  HloModule AllReduce
  sum {
    p0 = f32[] parameter(0)
    p1 = f32[] parameter(1)
    ROOT add = f32[] add(p0, p1)
  }

  ENTRY AllReduce {
    param0 = f32[3]{0} parameter(0)
    param1 = f32[12,5]{0,1} parameter(1)
    ROOT all-reduce = (bf16[3]{0}, bf16[12,5]{0,1}) all-reduce(param0, param1), to_apply=sum, replica_groups={}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  // Dump only at verbosity >= 3 instead of unconditionally on every run.
  XLA_VLOG_LINES(3, module->ToString());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Parameter(0), op::Parameter(1)));
}
341 
// An async all-reduce pair: both the start and the done are replaced by
// custom-calls, with the done's replacement consuming the start's replacement.
TEST_F(HloControlFlowFlatteningTest, AllReduceStartAndDone) {
  absl::string_view hlo_string = R"(
  HloModule CRS

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    crs = f32[8]{0} all-reduce-start(input), replica_groups={}, to_apply=add
    ROOT done = f32[8]{0} all-reduce-done(crs)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  // Dump only at verbosity >= 3 instead of unconditionally on every run.
  XLA_VLOG_LINES(3, module->ToString());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::CustomCall(op::Parameter(0))));
}
371 
// A synchronous all-gather is replaced by a custom-call over its input.
TEST_F(HloControlFlowFlatteningTest, AllGather) {
  absl::string_view hlo_string = R"(
  HloModule AllGather

  ENTRY AllGather {
    input = f32[128,32]{0,1} parameter(0)
    ROOT ag = f32[128,128]{0,1} all-gather(input), replica_groups={}, dimensions={1}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  // Dump only at verbosity >= 3 instead of unconditionally on every run.
  XLA_VLOG_LINES(3, module->ToString());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Parameter(0)));
}
394 
// An all-to-all (tuple-shaped result) is replaced by a custom-call over its
// input.
TEST_F(HloControlFlowFlatteningTest, AllToAll) {
  absl::string_view hlo_string = R"(
  HloModule AllToAll

  ENTRY AllToAll {
    input = f32[128,32]{0,1} parameter(0)
    ROOT a2a = (f32[128,32]{0,1}) all-to-all(input), replica_groups={}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  // Dump only at verbosity >= 3 instead of unconditionally on every run.
  XLA_VLOG_LINES(3, module->ToString());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Parameter(0)));
}
417 
// A plain collective-permute is replaced by a custom-call over its input.
TEST_F(HloControlFlowFlatteningTest, CollectivePermute) {
  absl::string_view hlo_string = R"(
  HloModule CollectivePermute

  ENTRY CollectivePermute {
    input = f32[128,32]{0,1} parameter(0)
    ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  // Dump only at verbosity >= 3 instead of unconditionally on every run.
  XLA_VLOG_LINES(3, module->ToString());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Parameter(0)));
}
440 
// The four-operand (in-place update) form of collective-permute is replaced
// by a custom-call that keeps all four operands: input, output buffer, and
// the two offset tuples.
TEST_F(HloControlFlowFlatteningTest, CollectivePermuteInPlaceUpdate) {
  absl::string_view hlo_string = R"(
  HloModule CollectivePermuteInPlaceUpdate

  ENTRY CollectivePermuteInPlaceUpdate {
    input = f32[128,32]{0,1} parameter(0)
    constant = f32[] constant(1)
    output = f32[128,128]{0,1} broadcast(constant), dimensions={}
    constant.1 = s32[] constant(0)
    tuple.1 = (s32[], s32[]) tuple(constant.1, constant.1)
    constant.2 = s32[] constant(64)
    tuple.2 = (s32[], s32[]) tuple(constant.1, constant.2)
    ROOT root = f32[128,128]{0,1} collective-permute(input, output, tuple.1, tuple.2), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{128,32}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  // Dump only at verbosity >= 3 instead of unconditionally on every run.
  XLA_VLOG_LINES(3, module->ToString());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Parameter(0), op::Broadcast(op::Constant()),
                             op::Tuple(op::Constant(), op::Constant()),
                             op::Tuple(op::Constant(), op::Constant())));
}
471 
// An async collective-permute pair: start and done both become custom-calls,
// chained in the same order as the original pair.
TEST_F(HloControlFlowFlatteningTest, CollectivePermuteStartAndDone) {
  absl::string_view hlo_string = R"(
  HloModule CollectivePermuteStartAndDone

  ENTRY CollectivePermuteStartAndDone {
    input = f32[128,32]{0,1} parameter(0)
    collective-permute-start.1 = (f32[128,32]{0,1}, f32[128,32]{0,1}, u32[], u32[]) collective-permute-start(input), source_target_pairs={{0,1},{1,2},{2,3}}
    ROOT collective-permute-done.1 = f32[128,32]{0,1} collective-permute-done(collective-permute-start.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  // Dump only at verbosity >= 3 instead of unconditionally on every run.
  XLA_VLOG_LINES(3, module->ToString());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::CustomCall(op::Parameter(0))));
}
495 
// A recv/recv-done pair at the entry root is replaced by a tuple of a
// custom-call (the received data) and an after-all token.
TEST_F(HloControlFlowFlatteningTest, Recv) {
  absl::string_view hlo_string = R"(
  HloModule Recv

  ENTRY %Recv () -> (f32[], token[]) {
    %token0 = token[] after-all()
    %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1}
    ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
    %constant = f32[] constant(2.1), sharding={maximal device=0}
    %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
    %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // The HLO declares control-predecessors={%recv}; ControlDepRemover (from
  // despecializer.h) is run first, presumably so those edges do not pin the
  // recv in place when it is replaced.
  ControlDepRemover control_remover;
  TF_ASSERT_OK(control_remover.Run(module.get()).status());
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  // Dump only at verbosity >= 3 instead of unconditionally on every run.
  XLA_VLOG_LINES(3, module->ToString());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::CustomCall(), op::AfterAll()));
}
523 
// With remove_comm=false but remove_host_transfer=true, a host-transfer recv
// (is_host_transfer=true) is still replaced by a tuple of a custom-call and an
// after-all token.
TEST_F(HloControlFlowFlatteningTest, RecvHostTransfer) {
  absl::string_view hlo_string = R"(
  HloModule Recv

  ENTRY %Recv () -> (f32[], token[]) {
    %token0 = token[] after-all()
    %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true, sharding={maximal device=1}
    ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true, sharding={maximal device=1}
    %constant = f32[] constant(2.1), sharding={maximal device=0}
    %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
    %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Drop the declared control-predecessors before flattening (see Recv).
  ControlDepRemover control_remover;
  TF_ASSERT_OK(control_remover.Run(module.get()).status());
  HloControlFlowFlattening flattening(HloControlFlowFlattening::Options{
      /*while_execution_count=*/3, /*max_outer_loop_count=*/3,
      /*max_loop_count=*/3, /*remove_infeed_outfeed=*/true,
      /*flatten_while_loop=*/true, /*remove_comm=*/false,
      /*remove_host_transfer=*/true});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  // Dump only at verbosity >= 3 instead of unconditionally on every run.
  XLA_VLOG_LINES(3, module->ToString());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::CustomCall(), op::AfterAll()));
}
554 
// A send/send-done pair at the entry root is replaced by a custom-call over
// the sent value and the after-all token.
TEST_F(HloControlFlowFlatteningTest, Send) {
  absl::string_view hlo_string = R"(
  HloModule Send

  ENTRY %Send () -> token[] {
    %token0 = token[] after-all()
    %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1}
    %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
    %constant = f32[] constant(2.1), sharding={maximal device=0}
    %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
    ROOT %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // The send declares control-predecessors={%recv}; ControlDepRemover (from
  // despecializer.h) is run first, presumably to drop that edge.
  ControlDepRemover control_remover;
  TF_ASSERT_OK(control_remover.Run(module.get()).status());
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  // Dump only at verbosity >= 3 instead of unconditionally on every run.
  XLA_VLOG_LINES(3, module->ToString());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Constant(), op::AfterAll()));
}
583 
// With remove_comm=false but remove_host_transfer=true, a host-transfer send
// (is_host_transfer=true) is still replaced by a custom-call over the sent
// value and the after-all token.
TEST_F(HloControlFlowFlatteningTest, SendHostTransfer) {
  absl::string_view hlo_string = R"(
  HloModule Send

  ENTRY %Send () -> token[] {
    %token0 = token[] after-all()
    %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1}
    %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
    %constant = f32[] constant(2.1), sharding={maximal device=0}
    %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true, sharding={maximal device=0}, control-predecessors={%recv}
    ROOT %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true, sharding={maximal device=0}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Drop the declared control-predecessors before flattening (see Send).
  ControlDepRemover control_remover;
  TF_ASSERT_OK(control_remover.Run(module.get()).status());
  HloControlFlowFlattening flattening(HloControlFlowFlattening::Options{
      /*while_execution_count=*/3, /*max_outer_loop_count=*/3,
      /*max_loop_count=*/3, /*remove_infeed_outfeed=*/true,
      /*flatten_while_loop=*/true, /*remove_comm=*/false,
      /*remove_host_transfer=*/true});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  // Dump only at verbosity >= 3 instead of unconditionally on every run.
  XLA_VLOG_LINES(3, module->ToString());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Constant(), op::AfterAll()));
}
615 
// An async all-gather pair: start and done both become custom-calls, with the
// done's replacement consuming the start's replacement.
TEST_F(HloControlFlowFlatteningTest, AllGatherStartAndDone) {
  absl::string_view hlo_string = R"(
  HloModule AllGatherStartAndDone

  ENTRY AllGatherStartAndDone {
    %input = f32[8,256,256] parameter(0)
    %ag-start = (f32[8,256,256], f32[16,256,256]) all-gather-start(
      f32[8,256,256] %input), replica_groups={{0,1}}, dimensions={0},
      metadata={op_type="AllGather" op_name="ag0"}
    ROOT %ag-done = f32[16,256,256] all-gather-done(
      (f32[8,256,256], f32[16,256,256]) %ag-start),
      metadata={op_type="AllGather" op_name="ag0"}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Report a failing Run() as a test failure instead of aborting the binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  // Dump only at verbosity >= 3 instead of unconditionally on every run.
  XLA_VLOG_LINES(3, module->ToString());
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::CustomCall(op::Parameter(0))));
}
643 
// A kCustom fusion whose body contains an all-reduce is treated as a
// collective. After flattening, the root is expected to be a single CustomCall
// taking the two entry parameters directly.
TEST_F(HloControlFlowFlatteningTest, CollectiveFusion) {
  absl::string_view hlo_template = R"(
HloModule collective-fusion, is_scheduled=true

%sum (a: f32[], b: f32[]) -> f32[] {
  %a = f32[] parameter(0)
  %b = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] a, f32[] b)
}

%all-gather {
  %constant.3 = f32[] constant(0)
  %broadcast = f32[full_size,8,128]{2,1,0} broadcast(%constant.3), dimensions={}
  %input.0 = f32[4,8,128]{2,1,0} parameter(0)
  %input.1 = f32[4,8,128]{2,1,0} parameter(1)
  %replica-id.1 = u32[] replica-id()
  %constant.4 = u32[] constant(4)
  %multiply.1 = u32[] multiply(%replica-id.1, %constant.4)
  %constant.5 = u32[] constant(0)
  %constant.6 = u32[] constant(0)
  %dynamic-update-slice = f32[full_size,8,128]{2,1,0} dynamic-update-slice(%broadcast, %input.0, %multiply.1, %constant.5, %constant.6)
  %dynamic-update-slice.1 = f32[full_size,8,128]{2,1,0} dynamic-update-slice(%broadcast, %input.1, %multiply.1, %constant.5, %constant.6)
  %all-reduce = (f32[full_size,8,128]{2,1,0}, f32[full_size,8,128]{2,1,0}) all-reduce(%dynamic-update-slice,  %dynamic-update-slice.1), replica_groups={}, backend_config="{barrier_config:{barrier_type:3,id:0}}", to_apply=%sum
  %gte0 = f32[full_size,8,128]{2,1,0} get-tuple-element(%all-reduce), index=0
  %slice = f32[unpadded_size,8,128]{2,1,0} slice(%gte0), slice={[0:unpadded_size], [0:8], [0:128]}
  %bitcast = f32[unpadded_size,1,8,128]{3,2,1,0} bitcast(%slice)
  %gte1 = f32[full_size,8,128]{2,1,0} get-tuple-element(%all-reduce), index=1
  ROOT %tuple = (f32[unpadded_size,1,8,128]{3,2,1,0}, f32[full_size,8,128]{2,1,0}) tuple(%bitcast, %gte1)
}

ENTRY main {
  %add.1 = f32[4,8,128]{2,1,0} parameter(0)
  %add.2 = f32[4,8,128]{2,1,0} parameter(1)
  ROOT %fusion = (f32[unpadded_size,1,8,128]{3,2,1,0}, f32[full_size,8,128]{2,1,0}) fusion(%add.1, %add.2), kind=kCustom, calls=%all-gather
}
  )";
  // Values substituted for the `full_size` / `unpadded_size` placeholders in
  // the template. The slice in %all-gather trims dimension 0 from the full
  // size down to the unpadded size.
  constexpr int kFullSize = 12288;
  constexpr int kUnpaddedSize = 12285;
  auto hlo_string = absl::StrReplaceAll(
      hlo_template, {{"full_size", absl::StrCat(kFullSize)},
                     {"unpadded_size", absl::StrCat(kUnpaddedSize)}});
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  EXPECT_TRUE(IsCollective(module->entry_computation()->root_instruction()));

  HloControlFlowFlattening flattening({});
  // Report a pass error as a test failure instead of aborting the whole test
  // binary, which is what ValueOrDie() would do on a non-OK status.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Parameter(0), op::Parameter(1)));
}
697 
CheckWhileBound(HloInstruction * while_op,int expected_bound)698 void CheckWhileBound(HloInstruction* while_op, int expected_bound) {
699   auto* cond = while_op->while_condition();
700   ASSERT_NE(cond, nullptr);
701   auto* hlo_bound = cond->root_instruction()->operand(1);
702   EXPECT_TRUE(hlo_bound->IsConstant());
703   if (hlo_bound->IsConstant()) {
704     EXPECT_TRUE(hlo_bound->literal().IsAll(expected_bound));
705   }
706 }
707 
// With max_outer_loop_count set, flattening should rewrite the outer while
// loop's bound to that maximum. The nested inner loop keeps its original
// bound of 100.
TEST_F(HloControlFlowFlatteningTest, MaxOuterLoopCount) {
  absl::string_view hlo_string = R"(
  HloModule NestedWhileComp

  InnerBody {
    constant.8 = pred[] constant(false)
    parameter.5 = (s32[], s32[]) parameter(0)
    get-tuple-element.6 = s32[] get-tuple-element(parameter.5), index=0
    constant.9 = s32[] constant(1)
    add.10 = s32[] add(get-tuple-element.6, constant.9)
    get-tuple-element.7 = s32[] get-tuple-element(parameter.5), index=1
    constant.11 = s32[] constant(1)
    add.12 = s32[] add(get-tuple-element.7, constant.11)
    ROOT tuple.13 = (s32[], s32[]) tuple(add.10, add.12)
  }

  InnerCond {
    parameter.15 = (s32[], s32[]) parameter(0)
    get-tuple-element.17 = s32[] get-tuple-element(parameter.15), index=1
    constant.18 = pred[] constant(false)
    get-tuple-element.16 = s32[] get-tuple-element(parameter.15), index=0
    inner_bound = s32[] constant(100)
    ROOT compare.20 = pred[] compare(get-tuple-element.16, inner_bound), direction=LT
  }

  OuterBody {
    constant.24 = pred[] constant(false)
    constant.25 = s32[] constant(0)
    parameter.22 = (s32[]) parameter(0)
    get-tuple-element.23 = s32[] get-tuple-element(parameter.22), index=0
    tuple.26 = (s32[], s32[]) tuple(constant.25, get-tuple-element.23)
    inner_while = (s32[], s32[]) while(tuple.26), condition=InnerCond, body=InnerBody
    get-tuple-element.28 = s32[] get-tuple-element(inner_while), index=0
    get-tuple-element.29 = s32[] get-tuple-element(inner_while), index=1
    tuple.30 = (s32[], s32[]) tuple(get-tuple-element.28, get-tuple-element.29)
    get-tuple-element.31 = s32[] get-tuple-element(tuple.30), index=0
    get-tuple-element.32 = s32[] get-tuple-element(tuple.30), index=1
    ROOT tuple.33 = (s32[]) tuple(get-tuple-element.32)
  }

  OuterCond {
    constant.37 = pred[] constant(false)
    parameter.35 = (s32[]) parameter(0)
    get-tuple-element.36 = s32[] get-tuple-element(parameter.35), index=0
    outer_bound = s32[] constant(1000)
    ROOT compare.39 = pred[] compare(get-tuple-element.36, outer_bound), direction=LT
  }

  ENTRY NestedWhileComp {
    constant.1 = pred[] constant(false)
    constant.2 = s32[] constant(0)
    tuple.3 = (s32[]) tuple(constant.2)
    outer_while = (s32[]) while(tuple.3), condition=OuterCond, body=OuterBody
    get-tuple-element.41 = s32[] get-tuple-element(outer_while), index=0
    tuple.42 = (s32[]) tuple(get-tuple-element.41)
    get-tuple-element.43 = s32[] get-tuple-element(tuple.42), index=0
    ROOT tuple.44 = (s32[]) tuple(get-tuple-element.43)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  constexpr int kWhileExecutionCount = 5;
  // Matches `inner_bound` in InnerCond above.
  constexpr int kExistingInnerLoopCount = 100;
  constexpr int kMaxLoopCount = 10;
  HloControlFlowFlattening flattening(HloControlFlowFlattening::Options{
      /*while_execution_count=*/kWhileExecutionCount,
      /*max_outer_loop_count=*/kMaxLoopCount});
  // Report a pass error as a test failure instead of aborting the whole test
  // binary, which is what ValueOrDie() would do on a non-OK status.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();

  auto* outer_while =
      module->entry_computation()->GetInstructionWithName("outer_while");
  ASSERT_NE(outer_while, nullptr);
  // Checks that the outer while loop has changed its loop bound.
  CheckWhileBound(outer_while, kMaxLoopCount);
  auto* while_body = outer_while->while_body();
  ASSERT_NE(while_body, nullptr);

  auto* inner_while = while_body->GetInstructionWithName("inner_while");
  ASSERT_NE(inner_while, nullptr);
  // Checks that the inner loop bound has not changed.
  CheckWhileBound(inner_while, kExistingInnerLoopCount);
}
796 
// For a `counter < constant` loop condition, GetLoopBound is expected to
// return the compared constant (100) rather than the fallback count it is
// given.
TEST_F(HloControlFlowFlatteningTest, MatchLtUseInferedLoopCount) {
  absl::string_view hlo_string = R"(
  HloModule While
  While.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  While.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(100)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  ENTRY While {
    constant.3 = s32[] constant(42)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=While.condition, body=While.body
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // Fallback passed to GetLoopBound; the inferred bound should win over it.
  constexpr int kFallbackLoopCount = 123;
  const HloInstruction& while_hlo =
      *module->entry_computation()->root_instruction();
  const int loop_bound =
      GetLoopBound(while_hlo, kFallbackLoopCount, kDefaultMaxLoopCount);
  EXPECT_EQ(loop_bound, 100);
}
829 
// For a `constant > counter` loop condition (operands in the mirrored order),
// GetLoopBound is expected to return the compared constant (50) rather than
// the fallback count it is given.
TEST_F(HloControlFlowFlatteningTest, MatchGtUseInferedLoopCount) {
  absl::string_view hlo_string = R"(
  HloModule While
  While.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  While.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(50)
    ROOT greater-than = pred[] compare(constant.2, get-tuple-element.3), direction=GT
  }
  ENTRY While {
    constant.3 = s32[] constant(42)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=While.condition, body=While.body
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // Fallback passed to GetLoopBound; the inferred bound should win over it.
  constexpr int kFallbackLoopCount = 123;
  const HloInstruction& while_hlo =
      *module->entry_computation()->root_instruction();
  const int loop_bound =
      GetLoopBound(while_hlo, kFallbackLoopCount, kDefaultMaxLoopCount);
  EXPECT_EQ(loop_bound, 50);
}
862 
// An equality (EQ) loop condition is not a pattern GetLoopBound infers a
// bound from, so it is expected to return the fallback count unchanged.
TEST_F(HloControlFlowFlatteningTest, NotMatchEqUseDefaultLoopCount) {
  absl::string_view hlo_string = R"(
  HloModule While
  While.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  While.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(100)
    ROOT equal = pred[] compare(get-tuple-element.3, constant.2), direction=EQ
  }
  ENTRY While {
    constant.3 = s32[] constant(42)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=While.condition, body=While.body
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // Fallback passed to GetLoopBound; expected back verbatim for EQ.
  constexpr int kFallbackLoopCount = 123;
  const HloInstruction& while_hlo =
      *module->entry_computation()->root_instruction();
  const int loop_bound =
      GetLoopBound(while_hlo, kFallbackLoopCount, kDefaultMaxLoopCount);
  EXPECT_EQ(loop_bound, kFallbackLoopCount);
}
895 
896 }  // namespace
897 }  // namespace xla
898