1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/tools/hlo_control_flow_flattening.h"
17
18 #include "absl/strings/str_replace.h"
19 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
20 #include "tensorflow/compiler/xla/service/despecializer.h"
21 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
22 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
23 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25
26 namespace xla {
27 namespace {
28
29 namespace op = xla::testing::opcode_matchers;
30
31 using HloControlFlowFlatteningTest = HloTestBase;
32
33 constexpr int kDefaultMaxLoopCount = 1000;
34
// A while loop that is the entry root. After flattening, the test expects:
//   * the entry root to be a tuple of the first two results of the loop,
//   * the loop input tuple to carry a third, constant element,
//   * the condition to compare that third element against a constant,
//   * the body to add a constant to that third element.
TEST_F(HloControlFlowFlatteningTest, WhileRoot) {
  absl::string_view hlo_string = R"(
HloModule While
While.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.1)
get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
}
While.condition {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(100)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY While {
constant.3 = s32[] constant(42)
constant.4 = s32[3]{0} constant({0, 1, 2})
tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=While.condition, body=While.body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());

  auto root = module->entry_computation()->root_instruction();
  auto while_op = module->entry_computation()->GetInstructionWithName("while");
  // The loop is dereferenced below; stop here if it cannot be found.
  ASSERT_NE(while_op, nullptr);
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(while_op, 0),
                              op::GetTupleElement(while_op, 1)));
  EXPECT_THAT(while_op,
              op::While(op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                                  op::Constant())));
  auto condition = while_op->while_condition();
  EXPECT_THAT(
      condition->root_instruction(),
      op::Compare(op::GetTupleElement(op::Parameter(0), 2), op::Constant()));

  auto body = while_op->while_body();
  EXPECT_THAT(body->root_instruction(),
              op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                        op::Add(op::GetTupleElement(op::Parameter(0), 2),
                                op::Constant())));
}
88
// Same expectations as the plain root-while case, but the loop condition
// delegates to a called computation that contains a side-effecting custom
// call. The rewritten condition is still expected to be a direct compare of
// the third loop-state element against a constant.
TEST_F(HloControlFlowFlatteningTest, WhileConditionCallComputation) {
  absl::string_view hlo_string = R"(
HloModule While
While.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.1)
get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
}
While.condition.called {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] custom-call(), custom_call_target="AllocateBuffer", custom_call_has_side_effect=true
less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
ROOT tuple.2 = (pred[]) tuple(less-than)
}
While.condition {
loop_var.3 = (s32[], s32[3]{0}) parameter(0)
call = (pred[]) call(loop_var.3), to_apply=While.condition.called
ROOT get-tuple-element.4 = pred[] get-tuple-element(call), index=0
}
ENTRY While {
constant.3 = s32[] constant(42)
constant.4 = s32[3]{0} constant({0, 1, 2})
tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=While.condition, body=While.body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  XLA_VLOG_LINES(3, "Loaded HLO module: " + module->ToString());
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());

  auto root = module->entry_computation()->root_instruction();
  auto while_op = module->entry_computation()->GetInstructionWithName("while");
  // The loop is dereferenced below; stop here if it cannot be found.
  ASSERT_NE(while_op, nullptr);
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(while_op, 0),
                              op::GetTupleElement(while_op, 1)));
  EXPECT_THAT(while_op,
              op::While(op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                                  op::Constant())));
  auto condition = while_op->while_condition();
  EXPECT_THAT(
      condition->root_instruction(),
      op::Compare(op::GetTupleElement(op::Parameter(0), 2), op::Constant()));

  auto body = while_op->while_body();
  EXPECT_THAT(body->root_instruction(),
              op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                        op::Add(op::GetTupleElement(op::Parameter(0), 2),
                                op::Constant())));
}
149
// Root-while case on a module declared with is_scheduled=true. Checks the
// rewritten entry root, the extended loop input tuple, and the rewritten
// condition; the loop body is not inspected here.
TEST_F(HloControlFlowFlatteningTest, WhileRootScheduled) {
  absl::string_view hlo_string = R"(
HloModule While, is_scheduled=true
While.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.1)
get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
}
While.condition {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(100)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY While {
constant.3 = s32[] constant(42)
constant.4 = s32[3]{0} constant({0, 1, 2})
tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=While.condition, body=While.body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());

  auto root = module->entry_computation()->root_instruction();
  auto while_op = module->entry_computation()->GetInstructionWithName("while");
  // The loop is dereferenced below; stop here if it cannot be found.
  ASSERT_NE(while_op, nullptr);
  EXPECT_THAT(root, op::Tuple(op::GetTupleElement(while_op, 0),
                              op::GetTupleElement(while_op, 1)));
  EXPECT_THAT(while_op,
              op::While(op::Tuple(op::GetTupleElement(), op::GetTupleElement(),
                                  op::Constant())));
  auto condition = while_op->while_condition();
  EXPECT_THAT(
      condition->root_instruction(),
      op::Compare(op::GetTupleElement(op::Parameter(0), 2), op::Constant()));
}
197
// A while loop whose result feeds a fusion. After flattening, the fusion is
// expected to consume a tuple rebuilt from the first two results of the loop
// instead of the loop itself.
TEST_F(HloControlFlowFlatteningTest, WhileUser) {
  absl::string_view hlo_string = R"(
HloModule While
While.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.1)
get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
}
While.condition {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(100)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
FusedComputation {
param = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.4 = s32[] get-tuple-element(param), index=0
get-tuple-element.5 = s32[3]{0} get-tuple-element(param), index=1
broadcast = s32[3]{0} broadcast(get-tuple-element.4), dimensions={}
ROOT add = s32[3]{0} add(broadcast, get-tuple-element.5)
}
ENTRY While {
constant.3 = s32[] constant(42)
constant.4 = s32[3]{0} constant({0, 1, 2})
tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
while = (s32[], s32[3]{0}) while(tuple.1), condition=While.condition, body=While.body
ROOT fusion = s32[3]{0} fusion(while), kind=kLoop, calls=FusedComputation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());

  auto fusion = module->entry_computation()->root_instruction();
  auto while_op = module->entry_computation()->GetInstructionWithName("while");
  // Fail with a clear message if the loop is missing, rather than through an
  // operand mismatch against a null pointer in the matcher below.
  ASSERT_NE(while_op, nullptr);
  EXPECT_THAT(fusion, op::Fusion(op::Tuple(op::GetTupleElement(while_op, 0),
                                           op::GetTupleElement(while_op, 1))));
}
246
// An infeed at the entry root is expected to be replaced by a tuple of a
// custom call and the original after-all token.
TEST_F(HloControlFlowFlatteningTest, Infeed) {
  absl::string_view hlo_string = R"(
HloModule Infeed
ENTRY Infeed {
after-all = token[] after-all()
ROOT infeed = ((bf16[3]{0}, s32[12,5]{0,1}), token[]) infeed(after-all)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  auto tuple = module->entry_computation()->root_instruction();
  EXPECT_THAT(tuple, op::Tuple(op::CustomCall(), op::AfterAll()));
}
267
// Infeed whose result carries a tiled layout. Besides the tuple(custom-call,
// after-all) structure, the replacement root must have exactly the shape the
// original infeed had, layout included.
TEST_F(HloControlFlowFlatteningTest, InfeedPreserveLayout) {
  absl::string_view hlo_string = R"(
HloModule Infeed
ENTRY Infeed {
after-all = token[] after-all()
ROOT infeed = ((bf16[3]{0}, s32[12,5]{0,1:T(8,128)}), token[]) infeed(after-all)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Snapshot the root shape before the pass rewrites the module.
  Shape root_shape = module->entry_computation()->root_instruction()->shape();
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  auto tuple = module->entry_computation()->root_instruction();
  EXPECT_THAT(tuple, op::Tuple(op::CustomCall(), op::AfterAll()));
  EXPECT_EQ(tuple->shape(), root_shape);
}
290
// An outfeed at the entry root is expected to be replaced by a custom call
// that keeps the same operands (the parameter and the after-all token).
TEST_F(HloControlFlowFlatteningTest, Outfeed) {
  absl::string_view hlo_string = R"(
HloModule Outfeed
ENTRY Outfeed {
param = (bf16[3]{0}, s32[12,5]{0,1}) parameter(0)
after-all = token[] after-all()
ROOT outfeed = token[] outfeed(param, after-all)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  auto custom_call = module->entry_computation()->root_instruction();
  EXPECT_THAT(custom_call, op::CustomCall(op::Parameter(0), op::AfterAll()));
}
312
// A two-operand all-reduce at the entry root is expected to be replaced by a
// custom call over the same two parameters.
TEST_F(HloControlFlowFlatteningTest, AllReduce) {
  absl::string_view hlo_string = R"(
HloModule AllReduce
sum {
p0 = f32[] parameter(0)
p1 = f32[] parameter(1)
ROOT add = f32[] add(p0, p1)
}

ENTRY AllReduce {
param0 = f32[3]{0} parameter(0)
param1 = f32[12,5]{0,1} parameter(1)
ROOT all-reduce = (bf16[3]{0}, bf16[12,5]{0,1}) all-reduce(param0, param1), to_apply=sum, replica_groups={}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Parameter(0), op::Parameter(1)));
}
341
// An async all-reduce pair (start + done) is expected to become two chained
// custom calls, the inner one fed by the original parameter.
TEST_F(HloControlFlowFlatteningTest, AllReduceStartAndDone) {
  absl::string_view hlo_string = R"(
HloModule CRS

add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}

ENTRY CRS {
input = f32[8]{0} parameter(0)
crs = f32[8]{0} all-reduce-start(input), replica_groups={}, to_apply=add
ROOT done = f32[8]{0} all-reduce-done(crs)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::CustomCall(op::Parameter(0))));
}
371
// An all-gather at the entry root is expected to be replaced by a custom call
// over the original parameter.
TEST_F(HloControlFlowFlatteningTest, AllGather) {
  absl::string_view hlo_string = R"(
HloModule AllGather

ENTRY AllGather {
input = f32[128,32]{0,1} parameter(0)
ROOT ag = f32[128,128]{0,1} all-gather(input), replica_groups={}, dimensions={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Parameter(0)));
}
394
// An all-to-all at the entry root is expected to be replaced by a custom call
// over the original parameter.
TEST_F(HloControlFlowFlatteningTest, AllToAll) {
  absl::string_view hlo_string = R"(
HloModule AllToAll

ENTRY AllToAll {
input = f32[128,32]{0,1} parameter(0)
ROOT a2a = (f32[128,32]{0,1}) all-to-all(input), replica_groups={}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Parameter(0)));
}
417
// A single-operand collective-permute at the entry root is expected to be
// replaced by a custom call over the original parameter.
TEST_F(HloControlFlowFlatteningTest, CollectivePermute) {
  absl::string_view hlo_string = R"(
HloModule CollectivePermute

ENTRY CollectivePermute {
input = f32[128,32]{0,1} parameter(0)
ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Parameter(0)));
}
440
// Four-operand (in-place update) collective-permute. The replacement custom
// call is expected to keep all four operands: input, output buffer, and the
// two offset tuples.
TEST_F(HloControlFlowFlatteningTest, CollectivePermuteInPlaceUpdate) {
  absl::string_view hlo_string = R"(
HloModule CollectivePermuteInPlaceUpdate

ENTRY CollectivePermuteInPlaceUpdate {
input = f32[128,32]{0,1} parameter(0)
constant = f32[] constant(1)
output = f32[128,128]{0,1} broadcast(constant), dimensions={}
constant.1 = s32[] constant(0)
tuple.1 = (s32[], s32[]) tuple(constant.1, constant.1)
constant.2 = s32[] constant(64)
tuple.2 = (s32[], s32[]) tuple(constant.1, constant.2)
ROOT root = f32[128,128]{0,1} collective-permute(input, output, tuple.1, tuple.2), source_target_pairs={{0,1},{1,2},{2,3}}, slice_sizes={{128,32}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Parameter(0), op::Broadcast(op::Constant()),
                             op::Tuple(op::Constant(), op::Constant()),
                             op::Tuple(op::Constant(), op::Constant())));
}
471
// An async collective-permute pair (start + done) is expected to become two
// chained custom calls, the inner one fed by the original parameter.
TEST_F(HloControlFlowFlatteningTest, CollectivePermuteStartAndDone) {
  absl::string_view hlo_string = R"(
HloModule CollectivePermuteStartAndDone

ENTRY CollectivePermuteStartAndDone {
input = f32[128,32]{0,1} parameter(0)
collective-permute-start.1 = (f32[128,32]{0,1}, f32[128,32]{0,1}, u32[], u32[]) collective-permute-start(input), source_target_pairs={{0,1},{1,2},{2,3}}
ROOT collective-permute-done.1 = f32[128,32]{0,1} collective-permute-done(collective-permute-start.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::CustomCall(op::Parameter(0))));
}
495
// Module rooted at a recv-done. Control dependencies are stripped first, then
// flattening runs; the root is expected to become a tuple of a custom call
// and an after-all token.
TEST_F(HloControlFlowFlatteningTest, Recv) {
  absl::string_view hlo_string = R"(
HloModule Recv

ENTRY %Recv () -> (f32[], token[]) {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1}
ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
%constant = f32[] constant(2.1), sharding={maximal device=0}
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0} }
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  ControlDepRemover control_remover;
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  TF_ASSERT_OK(control_remover.Run(module.get()).status());
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::CustomCall(), op::AfterAll()));
}
523
// Host-transfer recv with remove_comm disabled and remove_host_transfer
// enabled. The recv-done root is still expected to become a tuple of a custom
// call and an after-all token.
TEST_F(HloControlFlowFlatteningTest, RecvHostTransfer) {
  absl::string_view hlo_string = R"(
HloModule Recv

ENTRY %Recv () -> (f32[], token[]) {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true, sharding={maximal device=1}
ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true, sharding={maximal device=1}
%constant = f32[] constant(2.1), sharding={maximal device=0}
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0} }
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  ControlDepRemover control_remover;
  HloControlFlowFlattening flattening(HloControlFlowFlattening::Options{
      /*while_execution_count=*/3, /*max_outer_loop_count=*/3,
      /*max_loop_count=*/3, /*remove_infeed_outfeed=*/true,
      /*flatten_while_loop=*/true, /*remove_comm=*/false,
      /*remove_host_transfer=*/true});
  TF_ASSERT_OK(control_remover.Run(module.get()).status());
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Tuple(op::CustomCall(), op::AfterAll()));
}
554
// Module rooted at a send-done. Control dependencies are stripped first, then
// flattening runs; the root is expected to become a custom call over the sent
// constant and an after-all token.
TEST_F(HloControlFlowFlatteningTest, Send) {
  absl::string_view hlo_string = R"(
HloModule Send

ENTRY %Send () -> token[] {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1}
%recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
%constant = f32[] constant(2.1), sharding={maximal device=0}
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
ROOT %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  ControlDepRemover control_remover;
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  TF_ASSERT_OK(control_remover.Run(module.get()).status());
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Constant(), op::AfterAll()));
}
583
// Host-transfer send with remove_comm disabled and remove_host_transfer
// enabled. The send-done root is still expected to become a custom call over
// the sent constant and an after-all token.
TEST_F(HloControlFlowFlatteningTest, SendHostTransfer) {
  absl::string_view hlo_string = R"(
HloModule Send

ENTRY %Send () -> token[] {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1}
%recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
%constant = f32[] constant(2.1), sharding={maximal device=0}
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true, sharding={maximal device=0}, control-predecessors={%recv}
ROOT %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true, sharding={maximal device=0}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  ControlDepRemover control_remover;
  HloControlFlowFlattening flattening(HloControlFlowFlattening::Options{
      /*while_execution_count=*/3, /*max_outer_loop_count=*/3,
      /*max_loop_count=*/3, /*remove_infeed_outfeed=*/true,
      /*flatten_while_loop=*/true, /*remove_comm=*/false,
      /*remove_host_transfer=*/true});
  TF_ASSERT_OK(control_remover.Run(module.get()).status());
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Constant(), op::AfterAll()));
}
615
// An async all-gather pair (start + done) is expected to become two chained
// custom calls, the inner one fed by the original parameter.
TEST_F(HloControlFlowFlatteningTest, AllGatherStartAndDone) {
  absl::string_view hlo_string = R"(
HloModule AllGatherStartAndDone

ENTRY AllGatherStartAndDone {
%input = f32[8,256,256] parameter(0)
%ag-start = (f32[8,256,256], f32[16,256,256]) all-gather-start(
f32[8,256,256] %input), replica_groups={{0,1}}, dimensions={0},
metadata={op_type="AllGather" op_name="ag0"}
ROOT %ag-done = f32[16,256,256] all-gather-done(
(f32[8,256,256], f32[16,256,256]) %ag-start),
metadata={op_type="AllGather" op_name="ag0"}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  HloControlFlowFlattening flattening(
      HloControlFlowFlattening::Options{/*while_execution_count=*/3});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::CustomCall(op::Parameter(0))));
}
643
// A kCustom fusion wrapping an all-reduce in a scheduled module. The test
// first confirms IsCollective() accepts the fusion root, then runs flattening
// with default options and expects the root to become a custom call over the
// two entry parameters.
TEST_F(HloControlFlowFlatteningTest, CollectiveFusion) {
  // `full_size` and `unpadded_size` are placeholders substituted below.
  absl::string_view hlo_template = R"(
HloModule collective-fusion, is_scheduled=true

%sum (a: f32[], b: f32[]) -> f32[] {
%a = f32[] parameter(0)
%b = f32[] parameter(1)
ROOT %add = f32[] add(f32[] a, f32[] b)
}

%all-gather {
%constant.3 = f32[] constant(0)
%broadcast = f32[full_size,8,128]{2,1,0} broadcast(%constant.3), dimensions={}
%input.0 = f32[4,8,128]{2,1,0} parameter(0)
%input.1 = f32[4,8,128]{2,1,0} parameter(1)
%replica-id.1 = u32[] replica-id()
%constant.4 = u32[] constant(4)
%multiply.1 = u32[] multiply(%replica-id.1, %constant.4)
%constant.5 = u32[] constant(0)
%constant.6 = u32[] constant(0)
%dynamic-update-slice = f32[full_size,8,128]{2,1,0} dynamic-update-slice(%broadcast, %input.0, %multiply.1, %constant.5, %constant.6)
%dynamic-update-slice.1 = f32[full_size,8,128]{2,1,0} dynamic-update-slice(%broadcast, %input.1, %multiply.1, %constant.5, %constant.6)
%all-reduce = (f32[full_size,8,128]{2,1,0}, f32[full_size,8,128]{2,1,0}) all-reduce(%dynamic-update-slice, %dynamic-update-slice.1), replica_groups={}, backend_config="{barrier_config:{barrier_type:3,id:0}}", to_apply=%sum
%gte0 = f32[full_size,8,128]{2,1,0} get-tuple-element(%all-reduce), index=0
%slice = f32[unpadded_size,8,128]{2,1,0} slice(%gte0), slice={[0:unpadded_size], [0:8], [0:128]}
%bitcast = f32[unpadded_size,1,8,128]{3,2,1,0} bitcast(%slice)
%gte1 = f32[full_size,8,128]{2,1,0} get-tuple-element(%all-reduce), index=1
ROOT %tuple = (f32[unpadded_size,1,8,128]{3,2,1,0}, f32[full_size,8,128]{2,1,0}) tuple(%bitcast, %gte1)
}

ENTRY main {
%add.1 = f32[4,8,128]{2,1,0} parameter(0)
%add.2 = f32[4,8,128]{2,1,0} parameter(1)
ROOT %fusion = (f32[unpadded_size,1,8,128]{3,2,1,0}, f32[full_size,8,128]{2,1,0}) fusion(%add.1, %add.2), kind=kCustom, calls=%all-gather
}
)";
  auto hlo_string = absl::StrReplaceAll(
      hlo_template, {{"full_size", absl::StrCat(12288)},
                     {"unpadded_size", absl::StrCat(12285)}});
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  EXPECT_TRUE(IsCollective(module->entry_computation()->root_instruction()));

  HloControlFlowFlattening flattening({});
  // Unwrap through the status macro so an error from the pass is reported as
  // a test failure rather than aborting the test binary.
  TF_ASSERT_OK_AND_ASSIGN(bool changed, flattening.Run(module.get()));
  EXPECT_TRUE(changed);
  TF_ASSERT_OK(HloVerifier(/*layout_sensitive=*/true,
                           /*allow_mixed_precision=*/true)
                   .Run(module.get())
                   .status());
  LOG(INFO) << module->ToString();
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::CustomCall(op::Parameter(0), op::Parameter(1)));
}
697
CheckWhileBound(HloInstruction * while_op,int expected_bound)698 void CheckWhileBound(HloInstruction* while_op, int expected_bound) {
699 auto* cond = while_op->while_condition();
700 ASSERT_NE(cond, nullptr);
701 auto* hlo_bound = cond->root_instruction()->operand(1);
702 EXPECT_TRUE(hlo_bound->IsConstant());
703 if (hlo_bound->IsConstant()) {
704 EXPECT_TRUE(hlo_bound->literal().IsAll(expected_bound));
705 }
706 }
707
// Runs the flattening pass on a nested while loop with a max outer loop count
// of 10 and checks that the outer loop bound (originally 1000) becomes 10 while
// the inner loop bound (100) is left as it was.
TEST_F(HloControlFlowFlatteningTest, MaxOuterLoopCount) {
  const absl::string_view hlo_text = R"(
    HloModule NestedWhileComp

    InnerBody {
      constant.8 = pred[] constant(false)
      parameter.5 = (s32[], s32[]) parameter(0)
      get-tuple-element.6 = s32[] get-tuple-element(parameter.5), index=0
      constant.9 = s32[] constant(1)
      add.10 = s32[] add(get-tuple-element.6, constant.9)
      get-tuple-element.7 = s32[] get-tuple-element(parameter.5), index=1
      constant.11 = s32[] constant(1)
      add.12 = s32[] add(get-tuple-element.7, constant.11)
      ROOT tuple.13 = (s32[], s32[]) tuple(add.10, add.12)
    }

    InnerCond {
      parameter.15 = (s32[], s32[]) parameter(0)
      get-tuple-element.17 = s32[] get-tuple-element(parameter.15), index=1
      constant.18 = pred[] constant(false)
      get-tuple-element.16 = s32[] get-tuple-element(parameter.15), index=0
      inner_bound = s32[] constant(100)
      ROOT compare.20 = pred[] compare(get-tuple-element.16, inner_bound), direction=LT
    }

    OuterBody {
      constant.24 = pred[] constant(false)
      constant.25 = s32[] constant(0)
      parameter.22 = (s32[]) parameter(0)
      get-tuple-element.23 = s32[] get-tuple-element(parameter.22), index=0
      tuple.26 = (s32[], s32[]) tuple(constant.25, get-tuple-element.23)
      inner_while = (s32[], s32[]) while(tuple.26), condition=InnerCond, body=InnerBody
      get-tuple-element.28 = s32[] get-tuple-element(inner_while), index=0
      get-tuple-element.29 = s32[] get-tuple-element(inner_while), index=1
      tuple.30 = (s32[], s32[]) tuple(get-tuple-element.28, get-tuple-element.29)
      get-tuple-element.31 = s32[] get-tuple-element(tuple.30), index=0
      get-tuple-element.32 = s32[] get-tuple-element(tuple.30), index=1
      ROOT tuple.33 = (s32[]) tuple(get-tuple-element.32)
    }

    OuterCond {
      constant.37 = pred[] constant(false)
      parameter.35 = (s32[]) parameter(0)
      get-tuple-element.36 = s32[] get-tuple-element(parameter.35), index=0
      outer_bound = s32[] constant(1000)
      ROOT compare.39 = pred[] compare(get-tuple-element.36, outer_bound), direction=LT
    }

    ENTRY NestedWhileComp {
      constant.1 = pred[] constant(false)
      constant.2 = s32[] constant(0)
      tuple.3 = (s32[]) tuple(constant.2)
      outer_while = (s32[]) while(tuple.3), condition=OuterCond, body=OuterBody
      get-tuple-element.41 = s32[] get-tuple-element(outer_while), index=0
      tuple.42 = (s32[]) tuple(get-tuple-element.41)
      get-tuple-element.43 = s32[] get-tuple-element(tuple.42), index=0
      ROOT tuple.44 = (s32[]) tuple(get-tuple-element.43)
    }
  )";

  // Pass configuration and the bound the inner loop already has in the HLO.
  constexpr int kWhileExecutionCount = 5;
  constexpr int kMaxLoopCount = 10;
  constexpr int kExistingInnerLoopCount = 100;

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  const HloControlFlowFlattening::Options options{
      /*while_execution_count=*/kWhileExecutionCount,
      /*max_outer_loop_count=*/kMaxLoopCount};
  HloControlFlowFlattening flattening(options);
  const bool changed = flattening.Run(module.get()).ValueOrDie();
  EXPECT_TRUE(changed);

  // The rewritten module must still pass verification.
  HloVerifier verifier(/*layout_sensitive=*/true,
                       /*allow_mixed_precision=*/true);
  TF_ASSERT_OK(verifier.Run(module.get()).status());
  LOG(INFO) << module->ToString();

  // Outer loop: its bound is expected to have been rewritten to kMaxLoopCount.
  HloInstruction* outer_while =
      module->entry_computation()->GetInstructionWithName("outer_while");
  ASSERT_NE(outer_while, nullptr);
  CheckWhileBound(outer_while, kMaxLoopCount);

  // Inner loop: lives in the outer body and is expected to keep its bound.
  HloComputation* outer_body = outer_while->while_body();
  ASSERT_NE(outer_body, nullptr);
  HloInstruction* inner_while =
      outer_body->GetInstructionWithName("inner_while");
  ASSERT_NE(inner_while, nullptr);
  CheckWhileBound(inner_while, kExistingInnerLoopCount);
}
796
// GetLoopBound on a while loop whose condition is `counter < 100` (direction
// LT, constant as the second operand) is expected to return 100 rather than
// the value 123 passed as its second argument.
TEST_F(HloControlFlowFlatteningTest, MatchLtUseInferedLoopCount) {
  const absl::string_view hlo_text = R"(
  HloModule While
  While.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  While.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(100)
    ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
  }
  ENTRY While {
    constant.3 = s32[] constant(42)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=While.condition, body=While.body
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  // The entry root is the while instruction under test.
  const HloInstruction& while_hlo =
      *module->entry_computation()->root_instruction();
  const auto loop_bound = GetLoopBound(while_hlo, 123, kDefaultMaxLoopCount);
  EXPECT_EQ(loop_bound, 100);
}
829
// GetLoopBound on a while loop whose condition is `50 > counter` (direction
// GT, constant as the first operand) is expected to return 50 rather than the
// value 123 passed as its second argument.
TEST_F(HloControlFlowFlatteningTest, MatchGtUseInferedLoopCount) {
  const absl::string_view hlo_text = R"(
  HloModule While
  While.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  While.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(50)
    ROOT greater-than = pred[] compare(constant.2, get-tuple-element.3), direction=GT
  }
  ENTRY While {
    constant.3 = s32[] constant(42)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=While.condition, body=While.body
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  // The entry root is the while instruction under test.
  const HloInstruction& while_hlo =
      *module->entry_computation()->root_instruction();
  const auto loop_bound = GetLoopBound(while_hlo, 123, kDefaultMaxLoopCount);
  EXPECT_EQ(loop_bound, 50);
}
862
// GetLoopBound on a while loop whose condition uses an EQ comparison is
// expected to return the value 123 passed as its second argument instead of
// the constant 100 that appears in the condition.
TEST_F(HloControlFlowFlatteningTest, NotMatchEqUseDefaultLoopCount) {
  const absl::string_view hlo_text = R"(
  HloModule While
  While.body {
    loop_var.1 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
    constant.1 = s32[] constant(1)
    add = s32[] add(get-tuple-element.1, constant.1)
    get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
    multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
    ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
  }
  While.condition {
    loop_var.2 = (s32[], s32[3]{0}) parameter(0)
    get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
    constant.2 = s32[] constant(100)
    ROOT equal = pred[] compare(get-tuple-element.3, constant.2), direction=EQ
  }
  ENTRY While {
    constant.3 = s32[] constant(42)
    constant.4 = s32[3]{0} constant({0, 1, 2})
    tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
    ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=While.condition, body=While.body
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  // The entry root is the while instruction under test.
  const HloInstruction& while_hlo =
      *module->entry_computation()->root_instruction();
  const auto loop_bound = GetLoopBound(while_hlo, 123, kDefaultMaxLoopCount);
  EXPECT_EQ(loop_bound, 123);
}
895
896 } // namespace
897 } // namespace xla
898