xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_pass_pipeline_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 #include "tensorflow/compiler/xla/service/hlo_parser.h"
22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 
26 namespace xla {
27 namespace {
28 
29 using ::testing::ElementsAre;
30 using ::testing::SizeIs;
31 using ::testing::StrEq;
32 
33 class HloPassPipelineTest : public HloTestBase {
34  protected:
ParseModuleGroup(absl::Span<const std::string> hlo_strings)35   StatusOr<HloModuleGroup> ParseModuleGroup(
36       absl::Span<const std::string> hlo_strings) {
37     HloModuleGroup group(TestName());
38     for (const std::string& hlo_string : hlo_strings) {
39       TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
40                           ParseAndReturnVerifiedModule(hlo_string));
41       group.push_back(std::move(module));
42     }
43     return std::move(group);
44   }
45 };
46 
47 // A module pass which renames instructions named 'foo' to 'bar'.
48 class FooToBarModulePass : public HloModulePass {
name() const49   absl::string_view name() const override { return "foo2bar"; }
50 
51   using HloPassInterface::Run;
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)52   StatusOr<bool> Run(HloModule* module,
53                      const absl::flat_hash_set<absl::string_view>&
54                          execution_threads) override {
55     bool changed = false;
56     for (HloComputation* computation :
57          module->computations(execution_threads)) {
58       for (HloInstruction* instruction : computation->instructions()) {
59         if (instruction->name() == "foo") {
60           instruction->SetAndSanitizeName("bar");
61           changed = true;
62         }
63       }
64     }
65     return changed;
66   }
67 };
68 
69 // A module pass which renames root instructions names in reverse string order,
70 // e.g. "xyz" becomes "zyx".
71 class ReverseStringModulePass : public HloModulePass {
name() const72   absl::string_view name() const override { return "reverse"; }
73 
74   using HloPassInterface::Run;
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)75   StatusOr<bool> Run(HloModule* module,
76                      const absl::flat_hash_set<absl::string_view>&
77                          execution_threads) override {
78     bool changed = false;
79     for (HloComputation* computation :
80          module->computations(execution_threads)) {
81       HloInstruction* root = computation->root_instruction();
82       std::string name = root->name();
83       std::reverse(name.begin(), name.end());
84       root->SetAndSanitizeName(name);
85       changed = true;
86     }
87     return changed;
88   }
89 };
90 
91 // A module group pass which renames instructions named 'baz' to 'qux'.
92 class BazToQuxModuleGroupPass : public HloModuleGroupPass {
name() const93   absl::string_view name() const override { return "baz2qux"; }
94 
95   using HloPassInterface::RunOnModuleGroup;
RunOnModuleGroup(HloModuleGroup * module_group,const absl::flat_hash_set<absl::string_view> & execution_threads)96   StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group,
97                                   const absl::flat_hash_set<absl::string_view>&
98                                       execution_threads) override {
99     bool changed = false;
100     for (HloModule* module : module_group->modules()) {
101       for (HloComputation* computation :
102            module->computations(execution_threads)) {
103         for (HloInstruction* instruction : computation->instructions()) {
104           if (instruction->name() == "baz") {
105             instruction->SetAndSanitizeName("qux");
106             changed = true;
107           }
108         }
109       }
110     }
111     return changed;
112   }
113 };
114 
115 // An invariant checker pass which returns an error if there exists an
116 // instruction named 'bar'.
117 class BarBlowerUpper : public HloModulePass {
name() const118   absl::string_view name() const override { return "bar-blower-upper"; }
119 
120   using HloPassInterface::Run;
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)121   StatusOr<bool> Run(HloModule* module,
122                      const absl::flat_hash_set<absl::string_view>&
123                          execution_threads) override {
124     for (HloComputation* computation :
125          module->computations(execution_threads)) {
126       for (HloInstruction* instruction : computation->instructions()) {
127         if (instruction->name() == "bar") {
128           return InternalError("Module has instruction named bar");
129         }
130       }
131     }
132     return false;
133   }
134 };
135 
TEST_F(HloPassPipelineTest, ModulePassChanged) {
  // A pipeline holding only the foo-to-bar pass must rename the entry root
  // 'foo' and report that the module changed.
  const std::string module_str = R"(
HloModule ModulePassChanged

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT foo = f32[] multiply(a, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  EXPECT_EQ(entry_root->name(), "foo");

  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<FooToBarModulePass>();
  TF_ASSERT_OK_AND_ASSIGN(bool pipeline_changed, pipeline.Run(module.get()));

  EXPECT_TRUE(pipeline_changed);
  EXPECT_EQ(entry_root->name(), "bar");
}
158 
TEST_F(HloPassPipelineTest, ModulePassUnchanged) {
  // With no instruction named 'foo' in the module, the foo-to-bar pipeline
  // has nothing to do and must report the module as unchanged.
  const std::string module_str = R"(
HloModule ModulePassUnchanged

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT blahblah = f32[] multiply(a, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(module_str));

  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<FooToBarModulePass>();
  TF_ASSERT_OK_AND_ASSIGN(bool pipeline_changed, pipeline.Run(module.get()));

  EXPECT_FALSE(pipeline_changed);
}
178 
TEST_F(HloPassPipelineTest, ModulePassChangedForParallelThread) {
  // Runs the name-reversing pass restricted to the "parallel_thread" execution
  // thread. Only the root of the async-wrapped computation (which is assigned
  // to that thread) may be renamed; the entry computation's root must keep
  // its original name.
  const std::string module_str = R"(
HloModule ModulePassChangedForParallelThread
%async_builder {
  %p0 = f32[10] parameter(0)
  %p1 = f32[10] parameter(1)
  ROOT %foo = add(%p0, %p1)
}, execution_thread="parallel_thread"


ENTRY %Entry (p0: f32[10], p1: f32[10]) -> f32[10] {
  %p0 = f32[10] parameter(0)
  %p1 = f32[10] parameter(1)
  %async-start = ((f32[10], f32[10]), f32[10], s32[]) async-start(f32[10] %p0, f32[10] %p1), async_execution_thread="parallel_thread",calls=%async_builder
  ROOT %baz = f32[10]{0} async-done(((f32[10], f32[10]), f32[10], s32[]) %async-start), async_execution_thread="parallel_thread", calls=%async_builder
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<ReverseStringModulePass>();

  HloInstruction* main_root = module->entry_computation()->root_instruction();
  HloInstruction* parallel_thread_root =
      main_root->async_wrapped_computation()->root_instruction();
  EXPECT_EQ(main_root->name(), "baz");
  EXPECT_EQ(parallel_thread_root->name(), "foo");
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          pipeline.Run(module.get(), {"parallel_thread"}));
  EXPECT_TRUE(changed);
  // The entry computation is not on "parallel_thread", so it is untouched.
  EXPECT_EQ(main_root->name(), "baz");
  EXPECT_EQ(parallel_thread_root->name(), "oof");
}
213 
TEST_F(HloPassPipelineTest, ModulePassChangedForAllexecution_threads) {
  // Runs the name-reversing pass with no execution-thread restriction. The
  // roots of computations on every thread are renamed: both the entry
  // computation's root and the root of the async-wrapped "parallel_thread"
  // computation.
  const std::string module_str = R"(
HloModule ModulePassChangedForAllExecutionThreads
%async_builder {
  %p0 = f32[10] parameter(0)
  %p1 = f32[10] parameter(1)
  ROOT %foo = add(%p0, %p1)
}, execution_thread="parallel_thread"


ENTRY %Entry (p0: f32[10], p1: f32[10]) -> f32[10] {
  %p0 = f32[10] parameter(0)
  %p1 = f32[10] parameter(1)
  %async-start = ((f32[10], f32[10]), f32[10], s32[]) async-start(f32[10] %p0, f32[10] %p1), async_execution_thread="parallel_thread",calls=%async_builder
  ROOT %baz = f32[10]{0} async-done(((f32[10], f32[10]), f32[10], s32[]) %async-start), async_execution_thread="parallel_thread", calls=%async_builder
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<ReverseStringModulePass>();

  HloInstruction* main_root = module->entry_computation()->root_instruction();
  HloInstruction* parallel_thread_root =
      main_root->async_wrapped_computation()->root_instruction();
  EXPECT_EQ(main_root->name(), "baz");
  EXPECT_EQ(parallel_thread_root->name(), "foo");
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_EQ(main_root->name(), "zab");
  EXPECT_EQ(parallel_thread_root->name(), "oof");
}
248 
TEST_F(HloPassPipelineTest, MixedPipeline) {
  // Test a pipeline with both a module group pass and a module pass. Module 0
  // (root 'baz') is only affected by the group pass ('baz' -> 'qux'); module 1
  // (root 'foo') is only affected by the module pass ('foo' -> 'bar').
  // Module names match their index in the group to ease debugging.
  const std::string module_0_str = R"(
HloModule MixedPipeline.0

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT baz = f32[] multiply(a, b)
}
)";
  const std::string module_1_str = R"(
HloModule MixedPipeline.1

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT foo = f32[] multiply(a, b)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group,
                          ParseModuleGroup({module_0_str, module_1_str}));

  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<BazToQuxModuleGroupPass>();
  pipeline.AddPass<FooToBarModulePass>();

  HloInstruction* root0 =
      module_group.module(0).entry_computation()->root_instruction();
  HloInstruction* root1 =
      module_group.module(1).entry_computation()->root_instruction();
  EXPECT_EQ(root0->name(), "baz");
  EXPECT_EQ(root1->name(), "foo");

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          pipeline.RunOnModuleGroup(&module_group));
  EXPECT_TRUE(changed);

  EXPECT_EQ(root0->name(), "qux");
  EXPECT_EQ(root1->name(), "bar");
}
291 
TEST_F(HloPassPipelineTest, InvariantChecker) {
  // The same module is reused across the three scopes below, so their order
  // matters: the second scope renames 'foo' to 'bar' before the third runs.
  const std::string module_str = R"(
HloModule InvariantChecker

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT foo = f32[] multiply(a, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(module_str));

  {
    // Checker only, on a module with no 'bar': must succeed without changes.
    HloPassPipeline checker_only(TestName());
    checker_only.AddInvariantChecker<BarBlowerUpper>();

    TF_ASSERT_OK_AND_ASSIGN(bool checker_changed,
                            checker_only.Run(module.get()));
    EXPECT_FALSE(checker_changed);
  }

  {
    // The foo-to-bar pass introduces a 'bar', so the checker must fail and the
    // error must name the pass that ran just before the failed check.
    HloPassPipeline renaming(TestName());
    renaming.AddInvariantChecker<BarBlowerUpper>();
    renaming.AddPass<FooToBarModulePass>();

    Status run_status = renaming.Run(module.get()).status();
    ASSERT_IS_NOT_OK(run_status);
    EXPECT_THAT(run_status.error_message(),
                ::testing::HasSubstr("Module has instruction named bar"));
    EXPECT_THAT(run_status.error_message(),
                ::testing::HasSubstr("Failed after foo2bar"));
  }

  {
    // The module now contains 'bar', so the checker-only pipeline fails at
    // its initial check, which is reported as "pipeline-start".
    HloPassPipeline checker_only(TestName());
    checker_only.AddInvariantChecker<BarBlowerUpper>();

    Status run_status = checker_only.Run(module.get()).status();
    ASSERT_IS_NOT_OK(run_status);
    EXPECT_THAT(run_status.error_message(),
                ::testing::HasSubstr("Module has instruction named bar"));
    EXPECT_THAT(run_status.error_message(),
                ::testing::HasSubstr("Failed after pipeline-start"));
  }
}
342 
TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) {
  // A pipeline holding a module group pass cannot be run on a single module;
  // it must report an error saying so.
  const std::string module_str = R"(
HloModule ModuleGroupPassOnModule

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT foo = f32[] multiply(a, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(module_str));

  HloPassPipeline group_pipeline(TestName());
  group_pipeline.AddPass<BazToQuxModuleGroupPass>();

  Status run_status = group_pipeline.Run(module.get()).status();
  ASSERT_IS_NOT_OK(run_status);
  EXPECT_THAT(
      run_status.error_message(),
      ::testing::HasSubstr("Module group pass cannot be run on a module"));
}
365 
366 // Test that metadata is set when a module group goes through a pass pipeline.
TEST_F(HloPassPipelineTest,SetHloModuleMetadata)367 TEST_F(HloPassPipelineTest, SetHloModuleMetadata) {
368   HloModuleGroup module_group(TestName());
369   module_group.push_back(CreateNewVerifiedModule());
370   module_group.push_back(CreateNewVerifiedModule());
371 
372   HloPassPipeline pipeline(TestName());
373   pipeline.AddPass<BazToQuxModuleGroupPass>();
374   pipeline.AddPass<FooToBarModulePass>();
375   TF_ASSERT_OK(pipeline.RunOnModuleGroup(&module_group).status());
376   ASSERT_THAT(module_group.modules(), SizeIs(2));
377 
378   std::vector<std::string> pass_names = {"pipeline-start", "baz2qux",
379                                          "foo2bar"};
380   std::string pipeline_name = std::string(pipeline.name());
381   for (const HloModule* module : module_group.modules()) {
382     const HloModuleMetadataProto& metadata = module->metadata().proto();
383     EXPECT_EQ(metadata.canonical_module_id(), module->unique_id());
384     EXPECT_EQ(metadata.module_group_name(), module_group.name());
385 
386     ASSERT_THAT(metadata.pass_metadata(), SizeIs(3));
387     for (int pass = 0; pass < metadata.pass_metadata().size(); pass++) {
388       const HloPassMetadata& pass_metadata = metadata.pass_metadata(pass);
389       EXPECT_NE(pass_metadata.pass_id(), 0);
390       EXPECT_THAT(pass_metadata.pass_name(), StrEq(pass_names[pass]));
391       EXPECT_THAT(pass_metadata.pipeline_name(), StrEq(pipeline_name));
392       EXPECT_FALSE(pass_metadata.module_changed());
393       EXPECT_EQ(pass_metadata.module_id(), module->unique_id());
394       EXPECT_THAT(pass_metadata.module_group_module_ids(),
395                   ElementsAre(module_group.module(0).unique_id(),
396                               module_group.module(1).unique_id()));
397       EXPECT_GT(pass_metadata.start_timestamp_usec(), 0);
398       EXPECT_LE(pass_metadata.start_timestamp_usec(),
399                 pass_metadata.end_timestamp_usec());
400     }
401   }
402 }
403 
404 }  // namespace
405 }  // namespace xla
406