1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_module.h"
17
18 #include <memory>
19 #include <optional>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
29 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
30 #include "tensorflow/compiler/xla/service/test_compilation_environment.pb.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/test.h"
33 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36
37 namespace xla {
38
39 // In order to use TestCompilationEnvironment* with CompilationEnvironments, we
40 // must define CreateDefaultEnv for them.
41 template <>
42 std::unique_ptr<test::TestCompilationEnvironment1>
CreateDefaultEnv()43 CompilationEnvironments::CreateDefaultEnv<test::TestCompilationEnvironment1>() {
44 auto env = std::make_unique<test::TestCompilationEnvironment1>();
45 env->set_some_flag(100);
46 return env;
47 }
48
49 namespace {
50
51 namespace op = ::xla::testing::opcode_matchers;
52
53 class HloModuleTest : public HloTestBase {
54 protected:
HloModuleTest()55 HloModuleTest() {}
56
57 // Create a computation which returns a constant.
CreateConstantComputation()58 std::unique_ptr<HloComputation> CreateConstantComputation() {
59 auto builder = HloComputation::Builder("Constant");
60 builder.AddInstruction(
61 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
62 return builder.Build();
63 }
64
65 // Creates a computation which calls the given zero-parameter computations.
CreateCallComputation(absl::Span<HloComputation * const> computations)66 std::unique_ptr<HloComputation> CreateCallComputation(
67 absl::Span<HloComputation* const> computations) {
68 auto builder = HloComputation::Builder("Call");
69 for (auto computation : computations) {
70 builder.AddInstruction(
71 HloInstruction::CreateCall(r0f32_, {}, computation));
72 }
73 return builder.Build();
74 }
75
76 Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
77 };
78
TEST_F(HloModuleTest, OneComputationPostOrder) {
  // With the entry computation as the only computation in the module, the
  // post order must consist of exactly that computation.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(CreateConstantComputation());

  const auto post_order = module->MakeComputationPostOrder();
  EXPECT_THAT(post_order, ::testing::ElementsAre(entry));
}
87
TEST_F(HloModuleTest, TwoComputationsPostOrder) {
  // The entry and the embedded computation do not call each other, so their
  // relative position in the post order is not checked.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(CreateConstantComputation());
  HloComputation* embedded =
      module->AddEmbeddedComputation(CreateConstantComputation());

  const auto post_order = module->MakeComputationPostOrder();
  EXPECT_THAT(post_order, ::testing::UnorderedElementsAre(entry, embedded));

  // Both computations were built under the name "Constant"; the module is
  // expected to have renamed the second one to keep names unique.
  EXPECT_EQ(entry->name(), "Constant");
  EXPECT_EQ(embedded->name(), "Constant.1");
}
103
TEST_F(HloModuleTest, CloneTest) {
  // Build a diamond call graph: the entry computation calls `left` and
  // `right`, and each of those calls the shared `leaf` computation.
  auto module = CreateNewVerifiedModule();
  HloComputation* leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  module->AddEntryComputation(CreateCallComputation({left, right}));

  // Attach a compilation environment whose flag differs from the default
  // value, so that a freshly default-created environment in the clone would
  // be detected below.
  auto custom_env = std::make_unique<test::TestCompilationEnvironment1>();
  custom_env->set_some_flag(10);
  module->comp_envs().AddEnv(std::move(custom_env));

  const auto original_order = module->MakeComputationPostOrder();
  auto cloned_module = module->Clone("copy");
  const auto cloned_order = cloned_module->MakeComputationPostOrder();

  // The clone must carry the original module's CompilationEnvironments.
  const auto& cloned_env =
      cloned_module->comp_envs().GetEnv<test::TestCompilationEnvironment1>();
  EXPECT_EQ(cloned_env.some_flag(), 10);

  // Computations must line up one-to-one, each clone named after its origin
  // with the ".copy" suffix appended.
  EXPECT_EQ(original_order.size(), cloned_order.size());
  for (size_t i = 0; i < original_order.size() && i < cloned_order.size();
       ++i) {
    EXPECT_EQ(original_order[i]->name() + ".copy", cloned_order[i]->name());
  }
}
137
TEST_F(HloModuleTest, CloneHasFusion) {
  auto module = CreateNewVerifiedModule();

  // Fused computation: adds its scalar parameter to itself.
  HloComputation* fused_computation;
  {
    HloComputation::Builder fused_builder("Fused");
    HloInstruction* param = fused_builder.AddInstruction(
        HloInstruction::CreateParameter(0, r0f32_, "x"));
    fused_builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param));
    fused_computation = module->AddEmbeddedComputation(fused_builder.Build());
  }

  // Entry computation: feeds a constant into a fusion instruction that uses
  // the fused computation built above.
  {
    HloComputation::Builder entry_builder("Entry");
    HloInstruction* constant = entry_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
    entry_builder.AddInstruction(
        HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput,
                                     /*operands=*/{constant},
                                     fused_computation));
    module->AddEntryComputation(entry_builder.Build());
  }

  const auto original_order = module->MakeComputationPostOrder();
  auto cloned_module = module->Clone("copy");
  const auto cloned_order = cloned_module->MakeComputationPostOrder();

  EXPECT_EQ(original_order.size(), cloned_order.size());
  for (size_t i = 0; i < original_order.size() && i < cloned_order.size();
       ++i) {
    // The fused computation is cloned as part of cloning its fusion
    // instruction, which always appends ".clone"; every other computation
    // receives the suffix passed to Clone().
    const bool is_fused = original_order[i]->name() == "Fused";
    const char* const suffix = is_fused ? ".clone" : ".copy";
    EXPECT_EQ(original_order[i]->name() + suffix, cloned_order[i]->name());
  }
}
179
TEST_F(HloModuleTest, CloneCustomCallComputationToApply) {
  // Cloning a module must keep the link between a custom-call instruction and
  // the computation it references through `to_apply`.
  const char* const hlo_string = R"(
    HloModule a_module

    add_s32 {
      lhs = s32[] parameter(0)
      rhs = s32[] parameter(1)
      ROOT add = s32[] add(lhs, rhs)
    }

    ENTRY entry () -> s32[] {
      %c1 = s32[] constant(1)
      %c2 = s32[] constant(2)
      ROOT %custom-call =
        s32[] custom-call(s32[] %c1, %c2),
        custom_call_target="foo",
        backend_config="this string is opaque",
        to_apply=add_s32
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  std::unique_ptr<HloModule> cloned_module = module->Clone();
  HloComputation* cloned_computation =
      cloned_module->GetComputationWithName("add_s32.clone");
  HloInstruction* cloned_custom_call =
      cloned_module->entry_computation()->GetInstructionWithName("custom-call");

  // Abort this test cleanly if either lookup came back empty; the checks
  // below dereference cloned_computation.
  ASSERT_NE(cloned_computation, nullptr);
  ASSERT_NE(cloned_custom_call, nullptr);

  EXPECT_TRUE(cloned_computation->IsCustomCallComputation());
  EXPECT_EQ(cloned_computation->CustomCallInstruction(), cloned_custom_call);
}
211
TEST_F(HloModuleTest, CloneCustomCallComputationCalledComputations) {
  // Cloning a module must keep the link between a custom-call instruction and
  // every computation listed in its `called_computations`.
  const char* const hlo_string = R"(
    HloModule a_module

    add_s32_0 {
      lhs = s32[] parameter(0)
      rhs = s32[] parameter(1)
      ROOT add = s32[] add(lhs, rhs)
    }

    add_s32_1 {
      lhs = s32[] parameter(0)
      rhs = s32[] parameter(1)
      ROOT add = s32[] add(lhs, rhs)
    }

    ENTRY entry () -> s32[] {
      %c1 = s32[] constant(1)
      %c2 = s32[] constant(2)
      ROOT %custom-call =
        s32[] custom-call(s32[] %c1, %c2),
        custom_call_target="foo",
        backend_config="this string is opaque",
        called_computations={%add_s32_0, %add_s32_1}
    })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  std::unique_ptr<HloModule> cloned_module = module->Clone();
  HloComputation* cloned_computation_0 =
      cloned_module->GetComputationWithName("add_s32_0.clone");
  HloComputation* cloned_computation_1 =
      cloned_module->GetComputationWithName("add_s32_1.clone");
  HloInstruction* cloned_custom_call =
      cloned_module->entry_computation()->GetInstructionWithName("custom-call");

  // Abort this test cleanly if any lookup came back empty; the checks below
  // dereference both cloned computations.
  ASSERT_NE(cloned_computation_0, nullptr);
  ASSERT_NE(cloned_computation_1, nullptr);
  ASSERT_NE(cloned_custom_call, nullptr);

  EXPECT_TRUE(cloned_computation_0->IsCustomCallComputation());
  EXPECT_EQ(cloned_computation_0->CustomCallInstruction(), cloned_custom_call);
  EXPECT_TRUE(cloned_computation_1->IsCustomCallComputation());
  EXPECT_EQ(cloned_computation_1->CustomCallInstruction(), cloned_custom_call);
}
253
TEST_F(HloModuleTest, CloneFusionComputation) {
  // Cloning a module must keep the link between a fusion instruction and its
  // fused computation.
  const char* const hlo_string = R"(
    HloModule a_module

    fused_computation () -> s32[] {
      ROOT %result = s32[] parameter(0)
    }

    ENTRY main {
      %c = s32[] constant(1)
      ROOT %fusion = s32[] fusion(%c), kind=kLoop, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  std::unique_ptr<HloModule> cloned_module = module->Clone();
  HloComputation* cloned_computation =
      cloned_module->GetComputationWithName("fused_computation.clone");
  HloInstruction* cloned_fusion_instr =
      cloned_module->entry_computation()->GetInstructionWithName("fusion");

  // Abort this test cleanly if either lookup came back empty; the checks
  // below dereference cloned_computation.
  ASSERT_NE(cloned_computation, nullptr);
  ASSERT_NE(cloned_fusion_instr, nullptr);

  EXPECT_TRUE(cloned_computation->IsFusionComputation());
  EXPECT_EQ(cloned_computation->FusionInstruction(), cloned_fusion_instr);
}
279
TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
  // Diamond call graph: `entry` calls `left` and `right`, which both call
  // `leaf`.
  auto module = CreateNewVerifiedModule();
  HloComputation* leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* entry =
      module->AddEntryComputation(CreateCallComputation({left, right}));

  const auto post_order = module->MakeComputationPostOrder();
  EXPECT_THAT(post_order,
              ::testing::UnorderedElementsAre(leaf, left, right, entry));
  // Only the two ends are pinned: the shared callee comes first and the entry
  // computation comes last.
  EXPECT_EQ(post_order.back(), entry);
  EXPECT_EQ(post_order.front(), leaf);
}
299
TEST_F(HloModuleTest, LargeConstantToString) {
  // The module's only computation holds a 16-element f32 constant, printed
  // once with large constants elided and once in full.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder constant_builder("Constant");
  const std::vector<float> values(16, 42.0);
  constant_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(values)));
  module->AddEntryComputation(constant_builder.Build());

  // Large constants disabled: the literal is printed as "{...}".
  const auto elided_text =
      module->ToString(HloPrintOptions().set_print_large_constants(false));
  EXPECT_EQ(
      "HloModule LargeConstantToString, "
      "entry_computation_layout={()->f32[16]{0}}\n\nENTRY %Constant () -> "
      "f32[16] {\n ROOT %constant = f32[16]{0} constant({...})\n}\n\n",
      elided_text);

  // Large constants enabled: all sixteen values are printed.
  const auto full_text =
      module->ToString(HloPrintOptions().set_print_large_constants(true));
  EXPECT_EQ(
      "HloModule LargeConstantToString, "
      "entry_computation_layout={()->f32[16]{0}}\n\nENTRY %Constant () -> "
      "f32[16] {\n ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, "
      "42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n",
      full_text);
}
322
TEST_F(HloModuleTest, UniqueModuleId) {
  // Two independently created modules must not share a unique id.
  auto first_module = CreateNewVerifiedModule();
  auto second_module = CreateNewVerifiedModule();
  EXPECT_NE(first_module->unique_id(), second_module->unique_id());
}
328
TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) {
  // A module parsed without a schedule must still have no schedule after a
  // round trip through its proto form.
  const std::string hlo_text = R"(
HloModule axpy_module

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %x = f32[2,4]{1,0} parameter(1)
  %y = f32[2,4]{1,0} parameter(2)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
  ASSERT_FALSE(module->has_schedule());

  // Serialize to proto and rebuild a module from it.
  TF_ASSERT_OK_AND_ASSIGN(
      auto round_tripped,
      HloModule::CreateFromProto(module->ToProto(), module->config()));
  ASSERT_FALSE(round_tripped->has_schedule());
}
349
TEST_F(HloModuleTest, ProtoSerializationWithSchedule) {
  // A scheduled module must keep a valid schedule, with the same instruction
  // sequence, after a round trip through its proto form.
  const std::string hlo_text = R"(
HloModule axpy_module, is_scheduled=true

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %x = f32[2,4]{1,0} parameter(1)
  %y = f32[2,4]{1,0} parameter(2)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
  ASSERT_TRUE(module->has_schedule());

  // Serialize to proto and rebuild a module from it.
  TF_ASSERT_OK_AND_ASSIGN(
      auto round_tripped,
      HloModule::CreateFromProto(module->ToProto(), module->config()));
  ASSERT_TRUE(round_tripped->has_schedule());

  const auto& schedule = round_tripped->schedule();
  HloComputation* entry = round_tripped->entry_computation();
  TF_ASSERT_OK(schedule.Verify());
  EXPECT_EQ(schedule.sequences().size(), 1);
  ASSERT_TRUE(schedule.is_computation_scheduled(entry));

  // The scheduled order must match the order of the instructions in the text.
  EXPECT_THAT(
      schedule.sequence(entry).instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
                             op::Broadcast(), op::Multiply(), op::Add()));
}
380
TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
  // Verify that serializing then deserializing an HLO proto preserves the
  // unique IDs of the instructions and computations.
  const std::string text =
      R"(HloModule ReduceR3ToR2_module

add_F32.v3 {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY ReduceR3ToR2.v3 {
  input = f32[8,16,256]{2,1,0} parameter(0)
  constant = f32[] constant(0)
  ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));

  // Perform various transformations on the graph:
  //
  //  * clone the reduction function
  //  * replace use of reduction function with the clone.
  //  * add a random instruction to the entry computation.
  //
  // This will create instruction and computation IDs which are interesting:
  // not consecutive and not densely packed.
  HloComputation* entry = module->entry_computation();
  HloInstruction* root = entry->root_instruction();
  HloComputation* reduction = root->to_apply();
  HloComputation* reduction_clone =
      module->AddEmbeddedComputation(reduction->Clone());
  root->set_to_apply(reduction_clone);
  TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
  HloInstruction* negate = entry->AddInstruction(
      HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
  entry->set_root_instruction(negate);

  // Schedule the transformed module, this verifies that the serialized schedule
  // is robust against non-consecutive IDs as well (b/114712358).
  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  HloMemoryScheduler scheduler(size_fn);
  TF_ASSERT_OK(scheduler.Run(module.get()).status());
  ASSERT_TRUE(module->has_schedule());

  // Serialize and deserialize and verify that the instruction and computations
  // unique ids are the same.
  TF_ASSERT_OK_AND_ASSIGN(
      auto module_copy,
      HloModule::CreateFromProto(module->ToProto(), module->config()));

  // The module IDs should *not* be the same because module ids must be globally
  // unique.
  EXPECT_NE(module->unique_id(), module_copy->unique_id());

  // Walk the original and the deserialized module in lock step and verify that
  // the computations and instructions all have the same unique id. Each copy
  // iterator is compared against end() before it is dereferenced, so a copy
  // with fewer computations or instructions fails the test instead of reading
  // past the end of the range.
  auto computations_copy = module_copy->computations();
  auto computation_copy_it = computations_copy.begin();
  for (const HloComputation* computation_orig : module->computations()) {
    ASSERT_TRUE(computation_copy_it != computations_copy.end())
        << "Deserialized module has no counterpart for computation "
        << computation_orig->name();
    const HloComputation* computation_copy = *computation_copy_it++;
    EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
        << absl::StrFormat(
               "ID of original computation %s != ID of deserialized "
               "computation %s: %d != %d",
               computation_orig->name(), computation_copy->name(),
               computation_orig->unique_id(), computation_copy->unique_id());

    auto instructions_copy = computation_copy->instructions();
    auto instruction_copy_it = instructions_copy.begin();
    for (const HloInstruction* instruction_orig :
         computation_orig->instructions()) {
      ASSERT_TRUE(instruction_copy_it != instructions_copy.end())
          << "Deserialized computation " << computation_copy->name()
          << " has no counterpart for instruction "
          << instruction_orig->name();
      const HloInstruction* instruction_copy = *instruction_copy_it++;
      EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
          << absl::StrFormat(
                 "ID of original instruction %s != ID of deserialized "
                 "instruction %s: %d != %d",
                 instruction_orig->name(), instruction_copy->name(),
                 instruction_orig->unique_id(), instruction_copy->unique_id());
    }
    // The copy must not contain instructions beyond those of the original.
    EXPECT_TRUE(instruction_copy_it == instructions_copy.end())
        << "Deserialized computation " << computation_copy->name()
        << " has more instructions than the original";
  }
  // The copy must not contain computations beyond those of the original.
  EXPECT_TRUE(computation_copy_it == computations_copy.end())
      << "Deserialized module has more computations than the original";

  // Verify that the next unique ID which the module would have handed out is
  // greater than the unique id of any instruction.
  int next_id = module_copy->NewUniqueInstructionId();
  for (const HloComputation* computation : module_copy->computations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      EXPECT_GT(next_id, instruction->unique_id());
    }
  }
}
473
TEST_F(HloModuleTest, VerifyReplaceComputationsWithSortOp) {
  // ReplaceComputations must redirect the `to_apply` computation of a sort
  // instruction to the replacement computation.
  const std::string hlo_text = R"(
  HloModule sort

  compare {
    p.0.lhs = f32[] parameter(0)
    p.0.rhs = f32[] parameter(1)
    p.1.lhs = f32[] parameter(2)
    p.1.rhs = f32[] parameter(3)
    ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
  }

  ENTRY top {
    p.0 = f32[32] parameter(0)
    p.1 = f32[32] parameter(1)
    ROOT %sort.148.1589 = (f32[32], f32[32]) sort(p.0, p.1), dimensions={0}, to_apply=compare
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));

  // Replacement comparator: same four scalar parameters, but it compares the
  // first two with direction GT instead of LT.
  HloComputation* new_comp;
  {
    HloComputation::Builder comparator_builder("Fused");
    HloInstruction* lhs = comparator_builder.AddInstruction(
        HloInstruction::CreateParameter(0, r0f32_, "p0"));
    HloInstruction* rhs = comparator_builder.AddInstruction(
        HloInstruction::CreateParameter(1, r0f32_, "p1"));
    comparator_builder.AddInstruction(
        HloInstruction::CreateParameter(2, r0f32_, "p2"));
    comparator_builder.AddInstruction(
        HloInstruction::CreateParameter(3, r0f32_, "p3"));
    comparator_builder.AddInstruction(HloInstruction::CreateCompare(
        ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGt));
    new_comp = module->AddEmbeddedComputation(comparator_builder.Build());
  }

  // Before the replacement the sort still uses the parsed LT comparator.
  HloInstruction* sort_root = module->entry_computation()->root_instruction();
  HloComputation* old_comp = sort_root->to_apply();
  EXPECT_EQ(old_comp->root_instruction()->opcode(), HloOpcode::kCompare);
  EXPECT_EQ(old_comp->root_instruction()->comparison_direction(),
            ComparisonDirection::kLt);

  absl::flat_hash_map<HloComputation*, HloComputation*> replacement;
  replacement[old_comp] = new_comp;
  module->ReplaceComputations(replacement);

  EXPECT_EQ(sort_root->to_apply(), new_comp);
}
523
TEST_F(HloModuleTest, OneComputationAllAllowed) {
  // A lone entry computation that is listed in the allow-list must show up in
  // the filtered post order.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(CreateConstantComputation());

  const absl::flat_hash_set<HloComputation*> allow_list = {entry};
  const auto post_order =
      module->MakeComputationPostOrder(/*execution_threads=*/{}, allow_list);
  EXPECT_THAT(post_order, ::testing::ElementsAre(entry));
}
535
TEST_F(HloModuleTest, OneComputationAllFiltered) {
  // Create a module with a single computation and ensure an empty allow-list
  // filters it out of the post order.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(CreateConstantComputation());

  // The post order is computed once and checked; the earlier version also
  // made an identical call whose result was thrown away.
  absl::flat_hash_set<HloComputation*> allowList = {};
  EXPECT_THAT(
      module->MakeComputationPostOrder(/*execution_threads=*/{}, allowList),
      ::testing::IsEmpty());
}
547
TEST_F(HloModuleTest, DiamondComputationsPostOrderAllAllowed) {
  // Diamond call graph: `entry` calls `left` and `right`, which both call
  // `leaf`. Every computation is placed in the allow-list.
  auto module = CreateNewVerifiedModule();
  HloComputation* leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* entry =
      module->AddEntryComputation(CreateCallComputation({left, right}));

  const absl::flat_hash_set<HloComputation*> allow_list = {leaf, left, right,
                                                           entry};
  const auto post_order =
      module->MakeComputationPostOrder(/*execution_threads=*/{}, allow_list);
  EXPECT_THAT(post_order,
              ::testing::UnorderedElementsAre(leaf, left, right, entry));
  // Only the two ends are pinned: the shared callee comes first and the entry
  // computation comes last.
  EXPECT_EQ(post_order.back(), entry);
  EXPECT_EQ(post_order.front(), leaf);
}
570
TEST_F(HloModuleTest, DiamondComputationsPostOrderMiddleFiltered) {
  // Diamond call graph: `entry` calls `left` and `right`, which both call
  // `leaf`. Only the two ends of the diamond are placed in the allow-list.
  auto module = CreateNewVerifiedModule();
  HloComputation* leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* entry =
      module->AddEntryComputation(CreateCallComputation({left, right}));

  const absl::flat_hash_set<HloComputation*> allow_list = {leaf, entry};
  const auto post_order =
      module->MakeComputationPostOrder(/*execution_threads=*/{}, allow_list);
  // The two middle computations are absent from the allow-list and must be
  // absent from the result.
  EXPECT_THAT(post_order, ::testing::UnorderedElementsAre(leaf, entry));
}
589
TEST_F(HloModuleTest, DiamondComputationsPostOrderAllFiltered) {
  // Create a module with a diamond call graph of computations and ensure an
  // empty allow-list filters every computation out of the post order.
  auto module = CreateNewVerifiedModule();
  auto computation1 =
      module->AddEmbeddedComputation(CreateConstantComputation());
  auto computation2 =
      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
  auto computation3 =
      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
  module->AddEntryComputation(
      CreateCallComputation({computation2, computation3}));

  // Check the stored result directly; the earlier version left `post_order`
  // unused and recomputed the post order inside the expectation.
  absl::flat_hash_set<HloComputation*> allowList = {};
  auto post_order =
      module->MakeComputationPostOrder(/*execution_threads=*/{}, allowList);
  EXPECT_THAT(post_order, ::testing::IsEmpty());
}
609
TEST_F(HloModuleTest, TwoComputationsFilterexecution_threads) {
  // Builds a module with one computation per execution thread and checks that
  // filtering by thread name returns the matching computations, both through
  // MakeComputationPostOrder and through computations().
  HloComputation::Builder entry_builder(TestName());
  constexpr char kParallelThreadName[] = "parallel_thread";
  // The entry computation adds two scalar constants.
  HloInstruction* lhs = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* rhs = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
  HloInstruction* sum = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  auto* main_thread_computation =
      module->AddEntryComputation(entry_builder.Build());
  // Wrap the add in async instructions tagged with the parallel thread name;
  // the wrapped computation is the second computation of the module.
  TF_ASSERT_OK_AND_ASSIGN(
      auto* async_done,
      main_thread_computation->CreateAsyncInstructions(
          sum, {ShapeUtil::MakeScalarShape(U32)}, kParallelThreadName));
  auto* parallel_thread_computation = async_done->async_wrapped_computation();

  // Post order restricted to the main thread, unrestricted, and restricted to
  // the parallel thread.
  EXPECT_THAT(
      module->MakeComputationPostOrder({HloInstruction::kMainExecutionThread}),
      ::testing::ElementsAre(main_thread_computation));
  EXPECT_THAT(module->MakeComputationPostOrder(),
              ::testing::ElementsAre(parallel_thread_computation,
                                     main_thread_computation));
  EXPECT_THAT(module->MakeComputationPostOrder({kParallelThreadName}),
              ::testing::ElementsAre(parallel_thread_computation));

  // computations(execution_threads) must honor the same filters. An empty
  // thread set yields both computations.
  int all_count = 0;
  for ([[maybe_unused]] const HloComputation* unused :
       module->computations(/*execution_threads=*/{})) {
    ++all_count;
  }
  EXPECT_EQ(all_count, 2);

  int main_count = 0;
  for (const HloComputation* main_comp :
       module->computations({HloInstruction::kMainExecutionThread})) {
    ++main_count;
    EXPECT_EQ(main_comp->execution_thread(),
              HloInstruction::kMainExecutionThread);
  }
  EXPECT_EQ(main_count, 1);

  int parallel_count = 0;
  for (const HloComputation* parallel_comp :
       module->computations({kParallelThreadName})) {
    ++parallel_count;
    EXPECT_EQ(parallel_comp->execution_thread(), kParallelThreadName);
  }
  EXPECT_EQ(parallel_count, 1);
}
660
661 } // namespace
662
663 } // namespace xla
664