xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_module_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_module.h"
17 
18 #include <memory>
19 #include <optional>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
29 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
30 #include "tensorflow/compiler/xla/service/test_compilation_environment.pb.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/test.h"
33 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 
37 namespace xla {
38 
39 // In order to use TestCompilationEnvironment* with CompilationEnvironments, we
40 // must define CreateDefaultEnv for them.
41 template <>
42 std::unique_ptr<test::TestCompilationEnvironment1>
CreateDefaultEnv()43 CompilationEnvironments::CreateDefaultEnv<test::TestCompilationEnvironment1>() {
44   auto env = std::make_unique<test::TestCompilationEnvironment1>();
45   env->set_some_flag(100);
46   return env;
47 }
48 
49 namespace {
50 
51 namespace op = ::xla::testing::opcode_matchers;
52 
53 class HloModuleTest : public HloTestBase {
54  protected:
HloModuleTest()55   HloModuleTest() {}
56 
57   // Create a computation which returns a constant.
CreateConstantComputation()58   std::unique_ptr<HloComputation> CreateConstantComputation() {
59     auto builder = HloComputation::Builder("Constant");
60     builder.AddInstruction(
61         HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
62     return builder.Build();
63   }
64 
65   // Creates a computation which calls the given zero-parameter computations.
CreateCallComputation(absl::Span<HloComputation * const> computations)66   std::unique_ptr<HloComputation> CreateCallComputation(
67       absl::Span<HloComputation* const> computations) {
68     auto builder = HloComputation::Builder("Call");
69     for (auto computation : computations) {
70       builder.AddInstruction(
71           HloInstruction::CreateCall(r0f32_, {}, computation));
72     }
73     return builder.Build();
74   }
75 
76   Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
77 };
78 
TEST_F(HloModuleTest, OneComputationPostOrder) {
  // A module that holds only its entry computation must report exactly that
  // computation in its post order.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(CreateConstantComputation());

  const auto post_order = module->MakeComputationPostOrder();
  EXPECT_THAT(post_order, ::testing::ElementsAre(entry));
}
87 
TEST_F(HloModuleTest, TwoComputationsPostOrder) {
  // Two computations with no call edge between them: both have to show up in
  // the post order, in either order.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(CreateConstantComputation());
  HloComputation* embedded =
      module->AddEmbeddedComputation(CreateConstantComputation());

  const auto post_order = module->MakeComputationPostOrder();
  EXPECT_THAT(post_order, ::testing::UnorderedElementsAre(entry, embedded));

  // Both computations were built under the name "Constant"; the module is
  // expected to have renamed the second one to keep names unique.
  EXPECT_EQ(entry->name(), "Constant");
  EXPECT_EQ(embedded->name(), "Constant.1");
}
103 
TEST_F(HloModuleTest, CloneTest) {
  // Build a diamond call graph: the entry calls two middle computations, and
  // each of those calls the same constant computation.
  auto module = CreateNewVerifiedModule();
  HloComputation* leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  module->AddEntryComputation(CreateCallComputation({left, right}));

  // Attach a compilation environment whose flag differs from the default
  // (100), so a copied environment can be told apart from a default one.
  auto custom_env = std::make_unique<test::TestCompilationEnvironment1>();
  custom_env->set_some_flag(10);
  module->comp_envs().AddEnv(std::move(custom_env));

  const auto original_order = module->MakeComputationPostOrder();
  auto cloned_module = module->Clone("copy");
  const auto cloned_order = cloned_module->MakeComputationPostOrder();

  // The clone must carry over the module's CompilationEnvironments.
  const auto& cloned_env =
      cloned_module->comp_envs().GetEnv<test::TestCompilationEnvironment1>();
  EXPECT_EQ(cloned_env.some_flag(), 10);

  // Each cloned computation keeps its position in the post order and gets the
  // clone suffix appended to its name.
  EXPECT_EQ(original_order.size(), cloned_order.size());
  for (size_t i = 0; i < original_order.size() && i < cloned_order.size();
       ++i) {
    EXPECT_EQ(original_order[i]->name() + ".copy", cloned_order[i]->name());
  }
}
137 
TEST_F(HloModuleTest, CloneHasFusion) {
  auto module = CreateNewVerifiedModule();

  // Fused computation: x + x for a scalar f32 parameter x.
  HloComputation* fused_computation;
  {
    HloComputation::Builder fused_builder("Fused");
    HloInstruction* x = fused_builder.AddInstruction(
        HloInstruction::CreateParameter(0, r0f32_, "x"));
    fused_builder.AddInstruction(
        HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, x, x));
    fused_computation = module->AddEmbeddedComputation(fused_builder.Build());
  }

  // Entry computation: a fusion instruction applied to the constant 42.
  {
    HloComputation::Builder entry_builder("Entry");
    HloInstruction* input = entry_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
    entry_builder.AddInstruction(
        HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput,
                                     /*operands=*/{input}, fused_computation));
    module->AddEntryComputation(entry_builder.Build());
  }

  const auto original_order = module->MakeComputationPostOrder();
  auto cloned_module = module->Clone("copy");
  const auto cloned_order = cloned_module->MakeComputationPostOrder();

  EXPECT_EQ(original_order.size(), cloned_order.size());
  for (size_t i = 0; i < original_order.size() && i < cloned_order.size();
       ++i) {
    const auto& original_name = original_order[i]->name();
    // The fused computation is cloned as part of cloning its fusion
    // instruction, which always uses the suffix ".clone"; every other
    // computation gets the suffix passed to Clone().
    const char* expected_suffix = original_name == "Fused" ? ".clone" : ".copy";
    EXPECT_EQ(original_name + expected_suffix, cloned_order[i]->name());
  }
}
179 
TEST_F(HloModuleTest, CloneCustomCallComputationToApply) {
  // Clones a module whose custom-call uses `to_apply`, then checks that the
  // cloned callee is tied to the cloned custom-call instruction.
  const char* const hlo_string = R"(
HloModule a_module

add_s32 {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY entry () -> s32[] {
  %c1 = s32[] constant(1)
  %c2 = s32[] constant(2)
  ROOT %custom-call =
    s32[] custom-call(s32[] %c1, %c2),
    custom_call_target="foo",
    backend_config="this string is opaque",
    to_apply=add_s32
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  std::unique_ptr<HloModule> cloned_module = module->Clone();
  HloComputation* cloned_computation =
      cloned_module->GetComputationWithName("add_s32.clone");
  HloInstruction* cloned_custom_call =
      cloned_module->entry_computation()->GetInstructionWithName("custom-call");

  // Stop here if either lookup failed: the checks below dereference
  // `cloned_computation`, and comparing against a null instruction would not
  // be a meaningful check.
  ASSERT_NE(cloned_computation, nullptr);
  ASSERT_NE(cloned_custom_call, nullptr);

  EXPECT_TRUE(cloned_computation->IsCustomCallComputation());
  EXPECT_EQ(cloned_computation->CustomCallInstruction(), cloned_custom_call);
}
211 
TEST_F(HloModuleTest, CloneCustomCallComputationCalledComputations) {
  // Clones a module whose custom-call lists two `called_computations`, then
  // checks that both cloned callees are tied to the cloned custom-call.
  const char* const hlo_string = R"(
HloModule a_module

add_s32_0 {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

add_s32_1 {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT add = s32[] add(lhs, rhs)
}

ENTRY entry () -> s32[] {
  %c1 = s32[] constant(1)
  %c2 = s32[] constant(2)
  ROOT %custom-call =
    s32[] custom-call(s32[] %c1, %c2),
    custom_call_target="foo",
    backend_config="this string is opaque",
    called_computations={%add_s32_0, %add_s32_1}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  std::unique_ptr<HloModule> cloned_module = module->Clone();
  HloComputation* cloned_computation_0 =
      cloned_module->GetComputationWithName("add_s32_0.clone");
  HloComputation* cloned_computation_1 =
      cloned_module->GetComputationWithName("add_s32_1.clone");
  HloInstruction* cloned_custom_call =
      cloned_module->entry_computation()->GetInstructionWithName("custom-call");

  // Stop here if any lookup failed: the checks below dereference both
  // computations, and comparing against a null instruction would not be a
  // meaningful check.
  ASSERT_NE(cloned_computation_0, nullptr);
  ASSERT_NE(cloned_computation_1, nullptr);
  ASSERT_NE(cloned_custom_call, nullptr);

  EXPECT_TRUE(cloned_computation_0->IsCustomCallComputation());
  EXPECT_EQ(cloned_computation_0->CustomCallInstruction(), cloned_custom_call);
  EXPECT_TRUE(cloned_computation_1->IsCustomCallComputation());
  EXPECT_EQ(cloned_computation_1->CustomCallInstruction(), cloned_custom_call);
}
253 
TEST_F(HloModuleTest, CloneFusionComputation) {
  // Clones a module containing a fusion, then checks that the cloned fused
  // computation is tied to the cloned fusion instruction.
  const char* const hlo_string = R"(
HloModule a_module

fused_computation () -> s32[] {
  ROOT %result = s32[] parameter(0)
}

ENTRY main {
  %c = s32[] constant(1)
  ROOT %fusion = s32[] fusion(%c), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  std::unique_ptr<HloModule> cloned_module = module->Clone();
  HloComputation* cloned_computation =
      cloned_module->GetComputationWithName("fused_computation.clone");
  HloInstruction* cloned_fusion_instr =
      cloned_module->entry_computation()->GetInstructionWithName("fusion");

  // Stop here if either lookup failed: the checks below dereference
  // `cloned_computation`, and comparing against a null instruction would not
  // be a meaningful check.
  ASSERT_NE(cloned_computation, nullptr);
  ASSERT_NE(cloned_fusion_instr, nullptr);

  EXPECT_TRUE(cloned_computation->IsFusionComputation());
  EXPECT_EQ(cloned_computation->FusionInstruction(), cloned_fusion_instr);
}
279 
TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
  // Diamond call graph: entry -> {left, right} -> leaf.
  auto module = CreateNewVerifiedModule();
  HloComputation* leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* entry =
      module->AddEntryComputation(CreateCallComputation({left, right}));

  const auto post_order = module->MakeComputationPostOrder();
  EXPECT_THAT(post_order,
              ::testing::UnorderedElementsAre(leaf, left, right, entry));
  // The shared callee must come first and the entry last; the relative order
  // of the two middle computations is not checked.
  EXPECT_EQ(post_order.front(), leaf);
  EXPECT_EQ(post_order.back(), entry);
}
299 
TEST_F(HloModuleTest, LargeConstantToString) {
  // Module whose entry computation is a single 16-element f32 constant, used
  // to check how print_large_constants affects the printed text.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder constant_builder("Constant");
  const std::vector<float> values(16, 42.0f);
  constant_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(values)));
  module->AddEntryComputation(constant_builder.Build());

  // With large constants disabled the literal is printed as "{...}".
  const std::string expected_elided =
      "HloModule LargeConstantToString, "
      "entry_computation_layout={()->f32[16]{0}}\n\nENTRY %Constant () -> "
      "f32[16] {\n  ROOT %constant = f32[16]{0} constant({...})\n}\n\n";
  EXPECT_EQ(
      expected_elided,
      module->ToString(HloPrintOptions().set_print_large_constants(false)));

  // With large constants enabled every element is printed.
  const std::string expected_full =
      "HloModule LargeConstantToString, "
      "entry_computation_layout={()->f32[16]{0}}\n\nENTRY %Constant () -> "
      "f32[16] {\n  ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, "
      "42, 42, 42, 42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n";
  EXPECT_EQ(
      expected_full,
      module->ToString(HloPrintOptions().set_print_large_constants(true)));
}
322 
TEST_F(HloModuleTest, UniqueModuleId) {
  // Two independently created modules must not share a unique id.
  auto first_module = CreateNewVerifiedModule();
  auto second_module = CreateNewVerifiedModule();
  const auto first_id = first_module->unique_id();
  const auto second_id = second_module->unique_id();
  EXPECT_NE(first_id, second_id);
}
328 
TEST_F(HloModuleTest, ProtoSerializationWithoutSchedule) {
  // An unscheduled module must still be unscheduled after a round trip through
  // its proto form.
  const std::string text = R"(
HloModule axpy_module

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %x = f32[2,4]{1,0} parameter(1)
  %y = f32[2,4]{1,0} parameter(2)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto original, ParseAndReturnVerifiedModule(text));
  ASSERT_FALSE(original->has_schedule());

  const auto serialized = original->ToProto();
  TF_ASSERT_OK_AND_ASSIGN(
      auto round_tripped,
      HloModule::CreateFromProto(serialized, original->config()));
  ASSERT_FALSE(round_tripped->has_schedule());
}
349 
TEST_F(HloModuleTest, ProtoSerializationWithSchedule) {
  // A scheduled module must come back from its proto form with a valid
  // schedule that lists the entry instructions in the original order.
  const std::string text = R"(
HloModule axpy_module, is_scheduled=true

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %x = f32[2,4]{1,0} parameter(1)
  %y = f32[2,4]{1,0} parameter(2)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto original, ParseAndReturnVerifiedModule(text));
  ASSERT_TRUE(original->has_schedule());

  TF_ASSERT_OK_AND_ASSIGN(
      auto round_tripped,
      HloModule::CreateFromProto(original->ToProto(), original->config()));
  ASSERT_TRUE(round_tripped->has_schedule());

  const HloSchedule& schedule = round_tripped->schedule();
  HloComputation* entry = round_tripped->entry_computation();
  TF_ASSERT_OK(schedule.Verify());
  EXPECT_EQ(schedule.sequences().size(), 1);
  ASSERT_TRUE(schedule.is_computation_scheduled(entry));
  EXPECT_THAT(
      schedule.sequence(entry).instructions(),
      ::testing::ElementsAre(op::Parameter(), op::Parameter(), op::Parameter(),
                             op::Broadcast(), op::Multiply(), op::Add()));
}
380 
TEST_F(HloModuleTest,ProtoSerializationPreservesIds)381 TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
382   // Verify that serializing then deserializing an HLO proto preserves the
383   // unique IDs of the instruction and module.
384   const std::string text =
385       R"(HloModule ReduceR3ToR2_module
386 
387 add_F32.v3 {
388   lhs = f32[] parameter(0)
389   rhs = f32[] parameter(1)
390   ROOT add = f32[] add(lhs, rhs)
391 }
392 
393 ENTRY ReduceR3ToR2.v3 {
394   input = f32[8,16,256]{2,1,0} parameter(0)
395   constant = f32[] constant(0)
396   ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
397 }
398 )";
399   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
400 
401   // Perform various transformations on the graph:
402   //
403   //  * clone the reduction function
404   //  * replace use of reduction function with the clone.
405   //  * add a random instruction to the entry computation.
406   //
407   // This will create instruction and computation IDs which are interesting:
408   // not consecutive and not densely packed.
409   HloComputation* entry = module->entry_computation();
410   HloInstruction* root = entry->root_instruction();
411   HloComputation* reduction = root->to_apply();
412   HloComputation* reduction_clone =
413       module->AddEmbeddedComputation(reduction->Clone());
414   root->set_to_apply(reduction_clone);
415   TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
416   HloInstruction* negate = entry->AddInstruction(
417       HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
418   entry->set_root_instruction(negate);
419 
420   // Schedule the transformed module, this verifies that the serialized schedule
421   // is robust against non-consecutive IDs as well (b/114712358).
422   auto size_fn = [](const BufferValue& buffer) {
423     return ShapeUtil::ByteSizeOf(buffer.shape());
424   };
425   HloMemoryScheduler scheduler(size_fn);
426   TF_ASSERT_OK(scheduler.Run(module.get()).status());
427   ASSERT_TRUE(module->has_schedule());
428 
429   // Serialize and deserialize and verify that the instruction and computations
430   // unique ids are the same.
431   TF_ASSERT_OK_AND_ASSIGN(
432       auto module_copy,
433       HloModule::CreateFromProto(module->ToProto(), module->config()));
434 
435   // The module IDs should *not* be the same because module ids must be globally
436   // unique.
437   EXPECT_NE(module->unique_id(), module_copy->unique_id());
438 
439   // Verify that the computations and instructions all have the same unique id.
440   auto computation_copy = module_copy->computations();
441   auto computation_copy_it = computation_copy.begin();
442   for (const HloComputation* computation_orig : module->computations()) {
443     const HloComputation* computation_copy = *computation_copy_it++;
444     EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
445         << absl::StrFormat(
446                "ID of original computation %s != ID of deserialized "
447                "computation %s: %d != %d",
448                computation_orig->name(), computation_copy->name(),
449                computation_orig->unique_id(), computation_copy->unique_id());
450 
451     auto instruction_copy_it = computation_copy->instructions().begin();
452     for (const HloInstruction* instruction_orig :
453          computation_orig->instructions()) {
454       const HloInstruction* instruction_copy = *instruction_copy_it++;
455       EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
456           << absl::StrFormat(
457                  "ID of original instruction %s != ID of deserialized "
458                  "instruction %s: %d != %d",
459                  instruction_orig->name(), instruction_copy->name(),
460                  instruction_orig->unique_id(), instruction_copy->unique_id());
461     }
462   }
463 
464   // Verify that the next unique ID which the module would have handed out is
465   // greater than the unique id of any instruction.
466   int next_id = module_copy->NewUniqueInstructionId();
467   for (const HloComputation* computation : module_copy->computations()) {
468     for (const HloInstruction* instruction : computation->instructions()) {
469       EXPECT_GT(next_id, instruction->unique_id());
470     }
471   }
472 }
473 
TEST_F(HloModuleTest, VerifyReplaceComputationsWithSortOp) {
  // Replaces the comparator of a sort instruction through
  // HloModule::ReplaceComputations and checks that the sort then points at the
  // replacement.
  const std::string text = R"(
  HloModule sort

  compare {
      p.0.lhs = f32[] parameter(0)
      p.0.rhs = f32[] parameter(1)
      p.1.lhs = f32[] parameter(2)
      p.1.rhs = f32[] parameter(3)
      ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
  }

  ENTRY top {
    p.0 = f32[32] parameter(0)
    p.1 = f32[32] parameter(1)
    ROOT %sort.148.1589 = (f32[32], f32[32]) sort(p.0, p.1), dimensions={0}, to_apply=compare
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));

  // Replacement comparator: four scalar f32 parameters, with a greater-than
  // comparison of the first two.
  HloComputation* new_comp;
  {
    HloComputation::Builder replacement_builder("Fused");
    HloInstruction* lhs = replacement_builder.AddInstruction(
        HloInstruction::CreateParameter(0, r0f32_, "p0"));
    HloInstruction* rhs = replacement_builder.AddInstruction(
        HloInstruction::CreateParameter(1, r0f32_, "p1"));
    replacement_builder.AddInstruction(
        HloInstruction::CreateParameter(2, r0f32_, "p2"));
    replacement_builder.AddInstruction(
        HloInstruction::CreateParameter(3, r0f32_, "p3"));
    replacement_builder.AddInstruction(HloInstruction::CreateCompare(
        ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGt));
    new_comp = module->AddEmbeddedComputation(replacement_builder.Build());
  }

  // Before the replacement the sort uses the parsed less-than comparator.
  HloInstruction* sort = module->entry_computation()->root_instruction();
  HloInstruction* old_compare = sort->to_apply()->root_instruction();
  EXPECT_EQ(old_compare->opcode(), HloOpcode::kCompare);
  EXPECT_EQ(old_compare->comparison_direction(), ComparisonDirection::kLt);

  absl::flat_hash_map<HloComputation*, HloComputation*> replacement;
  replacement.insert({sort->to_apply(), new_comp});
  module->ReplaceComputations(replacement);

  EXPECT_EQ(sort->to_apply(), new_comp);
}
523 
TEST_F(HloModuleTest, OneComputationAllAllowed) {
  // A single-computation module whose only computation is on the allow-list:
  // the post order must contain exactly that computation.
  auto module = CreateNewVerifiedModule();
  HloComputation* entry =
      module->AddEntryComputation(CreateConstantComputation());

  absl::flat_hash_set<HloComputation*> allowList = {entry};
  const auto post_order =
      module->MakeComputationPostOrder(/*execution_threads=*/{}, allowList);
  EXPECT_THAT(post_order, ::testing::ElementsAre(entry));
}
535 
TEST_F(HloModuleTest, OneComputationAllFiltered) {
  // A single-computation module queried with an empty allow-list: the post
  // order must come back empty.
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(CreateConstantComputation());

  absl::flat_hash_set<HloComputation*> allowList = {};
  // Query once and assert on that result; an extra call whose return value is
  // thrown away adds nothing to the test.
  EXPECT_THAT(
      module->MakeComputationPostOrder(/*execution_threads=*/{}, allowList),
      ::testing::IsEmpty());
}
547 
TEST_F(HloModuleTest, DiamondComputationsPostOrderAllAllowed) {
  // Diamond call graph (entry -> {left, right} -> leaf) with every computation
  // on the allow-list.
  auto module = CreateNewVerifiedModule();
  HloComputation* leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* entry =
      module->AddEntryComputation(CreateCallComputation({left, right}));

  absl::flat_hash_set<HloComputation*> allowList = {leaf, left, right, entry};
  const auto post_order =
      module->MakeComputationPostOrder(/*execution_threads=*/{}, allowList);
  EXPECT_THAT(post_order,
              ::testing::UnorderedElementsAre(leaf, left, right, entry));
  // The shared callee must come first and the entry last.
  EXPECT_EQ(post_order.front(), leaf);
  EXPECT_EQ(post_order.back(), entry);
}
570 
TEST_F(HloModuleTest, DiamondComputationsPostOrderMiddleFiltered) {
  // Diamond call graph (entry -> {left, right} -> leaf) where only the leaf
  // and the entry are on the allow-list; the two middle computations must be
  // left out of the result.
  auto module = CreateNewVerifiedModule();
  HloComputation* leaf =
      module->AddEmbeddedComputation(CreateConstantComputation());
  HloComputation* left =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* right =
      module->AddEmbeddedComputation(CreateCallComputation({leaf}));
  HloComputation* entry =
      module->AddEntryComputation(CreateCallComputation({left, right}));

  absl::flat_hash_set<HloComputation*> allowList = {leaf, entry};
  const auto post_order =
      module->MakeComputationPostOrder(/*execution_threads=*/{}, allowList);
  EXPECT_THAT(post_order, ::testing::UnorderedElementsAre(leaf, entry));
}
589 
TEST_F(HloModuleTest, DiamondComputationsPostOrderAllFiltered) {
  // Create a module with a diamond call graph of computations and query it
  // with an empty allow-list: no computation may be returned.
  auto module = CreateNewVerifiedModule();
  auto computation1 =
      module->AddEmbeddedComputation(CreateConstantComputation());
  auto computation2 =
      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
  auto computation3 =
      module->AddEmbeddedComputation(CreateCallComputation({computation1}));
  module->AddEntryComputation(
      CreateCallComputation({computation2, computation3}));

  absl::flat_hash_set<HloComputation*> allowList = {};
  // Assert on the stored result rather than leaving it unused and issuing the
  // same query a second time.
  auto post_order =
      module->MakeComputationPostOrder(/*execution_threads=*/{}, allowList);
  EXPECT_THAT(post_order, ::testing::IsEmpty());
}
609 
TEST_F(HloModuleTest, TwoComputationsFilterexecution_threads) {
  // Builds an entry computation, wraps its add in async instructions assigned
  // to a second execution thread, and checks that filtering by thread name
  // selects the expected computations.
  constexpr char kParallelThreadName[] = "parallel_thread";
  HloComputation::Builder builder(TestName());
  // The entry computation adds two scalar constants.
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  auto* main_thread_computation = module->AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(
      auto* async_done,
      main_thread_computation->CreateAsyncInstructions(
          add, {ShapeUtil::MakeScalarShape(U32)}, kParallelThreadName));
  auto* parallel_thread_computation = async_done->async_wrapped_computation();

  // Post order filtered by thread name; no filter returns both computations,
  // with the wrapped computation ahead of the entry that uses it.
  EXPECT_THAT(
      module->MakeComputationPostOrder({HloInstruction::kMainExecutionThread}),
      ::testing::ElementsAre(main_thread_computation));
  EXPECT_THAT(module->MakeComputationPostOrder(),
              ::testing::ElementsAre(parallel_thread_computation,
                                     main_thread_computation));
  EXPECT_THAT(module->MakeComputationPostOrder({kParallelThreadName}),
              ::testing::ElementsAre(parallel_thread_computation));

  // computations(execution_threads) must apply the same filtering. An empty
  // thread set yields every computation.
  int all_count = 0;
  for ([[maybe_unused]] const HloComputation* unused :
       module->computations(/*execution_threads=*/{})) {
    ++all_count;
  }
  EXPECT_EQ(all_count, 2);

  int main_count = 0;
  for (const HloComputation* main_comp :
       module->computations({HloInstruction::kMainExecutionThread})) {
    EXPECT_EQ(main_comp->execution_thread(),
              HloInstruction::kMainExecutionThread);
    ++main_count;
  }
  EXPECT_EQ(main_count, 1);

  int parallel_count = 0;
  for (const HloComputation* parallel_comp :
       module->computations({kParallelThreadName})) {
    EXPECT_EQ(parallel_comp->execution_thread(), kParallelThreadName);
    ++parallel_count;
  }
  EXPECT_EQ(parallel_count, 1);
}
660 
661 }  // namespace
662 
663 }  // namespace xla
664