1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_module.h"
21 #include "tensorflow/compiler/xla/service/hlo_parser.h"
22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25
26 namespace xla {
27 namespace {
28
29 using ::testing::ElementsAre;
30 using ::testing::SizeIs;
31 using ::testing::StrEq;
32
33 class HloPassPipelineTest : public HloTestBase {
34 protected:
ParseModuleGroup(absl::Span<const std::string> hlo_strings)35 StatusOr<HloModuleGroup> ParseModuleGroup(
36 absl::Span<const std::string> hlo_strings) {
37 HloModuleGroup group(TestName());
38 for (const std::string& hlo_string : hlo_strings) {
39 TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> module,
40 ParseAndReturnVerifiedModule(hlo_string));
41 group.push_back(std::move(module));
42 }
43 return std::move(group);
44 }
45 };
46
47 // A module pass which renames instructions named 'foo' to 'bar'.
48 class FooToBarModulePass : public HloModulePass {
name() const49 absl::string_view name() const override { return "foo2bar"; }
50
51 using HloPassInterface::Run;
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)52 StatusOr<bool> Run(HloModule* module,
53 const absl::flat_hash_set<absl::string_view>&
54 execution_threads) override {
55 bool changed = false;
56 for (HloComputation* computation :
57 module->computations(execution_threads)) {
58 for (HloInstruction* instruction : computation->instructions()) {
59 if (instruction->name() == "foo") {
60 instruction->SetAndSanitizeName("bar");
61 changed = true;
62 }
63 }
64 }
65 return changed;
66 }
67 };
68
69 // A module pass which renames root instructions names in reverse string order,
70 // e.g. "xyz" becomes "zyx".
71 class ReverseStringModulePass : public HloModulePass {
name() const72 absl::string_view name() const override { return "reverse"; }
73
74 using HloPassInterface::Run;
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)75 StatusOr<bool> Run(HloModule* module,
76 const absl::flat_hash_set<absl::string_view>&
77 execution_threads) override {
78 bool changed = false;
79 for (HloComputation* computation :
80 module->computations(execution_threads)) {
81 HloInstruction* root = computation->root_instruction();
82 std::string name = root->name();
83 std::reverse(name.begin(), name.end());
84 root->SetAndSanitizeName(name);
85 changed = true;
86 }
87 return changed;
88 }
89 };
90
91 // A module group pass which renames instructions named 'baz' to 'qux'.
92 class BazToQuxModuleGroupPass : public HloModuleGroupPass {
name() const93 absl::string_view name() const override { return "baz2qux"; }
94
95 using HloPassInterface::RunOnModuleGroup;
RunOnModuleGroup(HloModuleGroup * module_group,const absl::flat_hash_set<absl::string_view> & execution_threads)96 StatusOr<bool> RunOnModuleGroup(HloModuleGroup* module_group,
97 const absl::flat_hash_set<absl::string_view>&
98 execution_threads) override {
99 bool changed = false;
100 for (HloModule* module : module_group->modules()) {
101 for (HloComputation* computation :
102 module->computations(execution_threads)) {
103 for (HloInstruction* instruction : computation->instructions()) {
104 if (instruction->name() == "baz") {
105 instruction->SetAndSanitizeName("qux");
106 changed = true;
107 }
108 }
109 }
110 }
111 return changed;
112 }
113 };
114
115 // An invariant checker pass which returns an error if there exists an
116 // instruction named 'bar'.
117 class BarBlowerUpper : public HloModulePass {
name() const118 absl::string_view name() const override { return "bar-blower-upper"; }
119
120 using HloPassInterface::Run;
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)121 StatusOr<bool> Run(HloModule* module,
122 const absl::flat_hash_set<absl::string_view>&
123 execution_threads) override {
124 for (HloComputation* computation :
125 module->computations(execution_threads)) {
126 for (HloInstruction* instruction : computation->instructions()) {
127 if (instruction->name() == "bar") {
128 return InternalError("Module has instruction named bar");
129 }
130 }
131 }
132 return false;
133 }
134 };
135
TEST_F(HloPassPipelineTest, ModulePassChanged) {
  // A module pass that renames the root 'foo' to 'bar' must report a change,
  // and the rename must be visible through the pre-existing root pointer.
  const std::string module_str = R"(
HloModule ModulePassChanged

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT foo = f32[] multiply(a, b)
}
)";
  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<FooToBarModulePass>();

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  const HloInstruction* const entry_root =
      module->entry_computation()->root_instruction();
  EXPECT_EQ(entry_root->name(), "foo");

  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_EQ(entry_root->name(), "bar");
}
158
TEST_F(HloPassPipelineTest, ModulePassUnchanged) {
  // FooToBarModulePass has nothing to rename in a module without a 'foo'
  // instruction, so the pipeline must report that nothing changed.
  const std::string module_str = R"(
HloModule ModulePassUnchanged

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT blahblah = f32[] multiply(a, b)
}
)";
  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<FooToBarModulePass>();

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
  EXPECT_FALSE(changed);
}
178
TEST_F(HloPassPipelineTest, ModulePassChangedForParallelThread) {
  // When the pipeline runs restricted to "parallel_thread", only the async
  // wrapped computation's root is renamed; the entry computation's root keeps
  // its name.
  const std::string module_str = R"(
HloModule ModulePassChanged
%async_builder {
  %p0 = f32[10] parameter(0)
  %p1 = f32[10] parameter(1)
  ROOT %foo = add(%p0, %p1)
}, execution_thread="parallel_thread"


ENTRY %Entry (p0: f32[10], p1: f32[10]) -> f32[10] {
  %p0 = f32[10] parameter(0)
  %p1 = f32[10] parameter(1)
  %async-start = ((f32[10], f32[10]), f32[10], s32[]) async-start(f32[10] %p0, f32[10] %p1), async_execution_thread="parallel_thread",calls=%async_builder
  ROOT %baz = f32[10]{0} async-done(((f32[10], f32[10]), f32[10], s32[]) %async-start), async_execution_thread="parallel_thread", calls=%async_builder
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<ReverseStringModulePass>();

  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  HloInstruction* async_root =
      entry_root->async_wrapped_computation()->root_instruction();
  EXPECT_EQ(entry_root->name(), "baz");
  EXPECT_EQ(async_root->name(), "foo");

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          pipeline.Run(module.get(), {"parallel_thread"}));
  EXPECT_TRUE(changed);
  EXPECT_EQ(entry_root->name(), "baz");
  EXPECT_EQ(async_root->name(), "oof");
}
213
TEST_F(HloPassPipelineTest, ModulePassChangedForAllexecution_threads) {
  // Without an execution-thread filter, the pass reaches both the entry
  // computation and the async wrapped computation, so both roots are renamed.
  const std::string module_str = R"(
HloModule ModulePassChanged
%async_builder {
  %p0 = f32[10] parameter(0)
  %p1 = f32[10] parameter(1)
  ROOT %foo = add(%p0, %p1)

}, execution_thread="parallel_thread"


ENTRY %Entry (p0: f32[10], p1: f32[10]) -> f32[10] {
  %p0 = f32[10] parameter(0)
  %p1 = f32[10] parameter(1)
  %async-start = ((f32[10], f32[10]), f32[10], s32[]) async-start(f32[10] %p0, f32[10] %p1), async_execution_thread="parallel_thread",calls=%async_builder
  ROOT %baz = f32[10]{0} async-done(((f32[10], f32[10]), f32[10], s32[]) %async-start), async_execution_thread="parallel_thread", calls=%async_builder
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<ReverseStringModulePass>();

  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  HloInstruction* async_root =
      entry_root->async_wrapped_computation()->root_instruction();
  EXPECT_EQ(entry_root->name(), "baz");
  EXPECT_EQ(async_root->name(), "foo");

  TF_ASSERT_OK_AND_ASSIGN(bool changed, pipeline.Run(module.get()));
  EXPECT_TRUE(changed);
  EXPECT_EQ(entry_root->name(), "zab");
  EXPECT_EQ(async_root->name(), "oof");
}
248
TEST_F(HloPassPipelineTest, MixedPipeline) {
  // A pipeline holding both a module group pass (baz2qux) and a module pass
  // (foo2bar), run on a module group, applies both passes to the group.
  // NOTE(review): module_0_str declares "MixedPipeline.1" and module_1_str
  // declares "MixedPipeline.0"; module(i) follows parse order, not these names.
  const std::string module_0_str = R"(
HloModule MixedPipeline.1

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT baz = f32[] multiply(a, b)
}
)";
  const std::string module_1_str = R"(
HloModule MixedPipeline.0

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT foo = f32[] multiply(a, b)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(HloModuleGroup module_group,
                          ParseModuleGroup({module_0_str, module_1_str}));

  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<BazToQuxModuleGroupPass>();
  pipeline.AddPass<FooToBarModulePass>();

  HloInstruction* first_root =
      module_group.module(0).entry_computation()->root_instruction();
  HloInstruction* second_root =
      module_group.module(1).entry_computation()->root_instruction();
  EXPECT_EQ(first_root->name(), "baz");
  EXPECT_EQ(second_root->name(), "foo");

  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          pipeline.RunOnModuleGroup(&module_group));
  EXPECT_TRUE(changed);

  EXPECT_EQ(first_root->name(), "qux");
  EXPECT_EQ(second_root->name(), "bar");
}
291
TEST_F(HloPassPipelineTest, InvariantChecker) {
  // Invariant checkers run before any pass ("pipeline-start") and after a pass
  // ("foo2bar"); a checker failure reports which of those points it followed.
  const std::string module_str = R"(
HloModule InvariantChecker

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT foo = f32[] multiply(a, b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(module_str));

  // Checker only, on a module without any 'bar' instruction: succeeds and
  // reports no change.
  {
    HloPassPipeline checker_only(TestName());
    checker_only.AddInvariantChecker<BarBlowerUpper>();

    TF_ASSERT_OK_AND_ASSIGN(bool changed, checker_only.Run(module.get()));
    EXPECT_FALSE(changed);
  }

  // foo2bar introduces an instruction named 'bar', which the checker then
  // rejects; the error must point at foo2bar.
  {
    HloPassPipeline renaming(TestName());
    renaming.AddInvariantChecker<BarBlowerUpper>();
    renaming.AddPass<FooToBarModulePass>();

    const Status status = renaming.Run(module.get()).status();
    ASSERT_IS_NOT_OK(status);
    EXPECT_THAT(status.error_message(),
                ::testing::HasSubstr("Module has instruction named bar"));
    EXPECT_THAT(status.error_message(),
                ::testing::HasSubstr("Failed after foo2bar"));
  }

  // The module still contains 'bar' from the previous run, so the checker-only
  // pipeline now fails before any pass executes.
  {
    HloPassPipeline checker_only(TestName());
    checker_only.AddInvariantChecker<BarBlowerUpper>();

    const Status status = checker_only.Run(module.get()).status();
    ASSERT_IS_NOT_OK(status);
    EXPECT_THAT(status.error_message(),
                ::testing::HasSubstr("Module has instruction named bar"));
    EXPECT_THAT(status.error_message(),
                ::testing::HasSubstr("Failed after pipeline-start"));
  }
}
342
TEST_F(HloPassPipelineTest, ModuleGroupPassOnModule) {
  // A pipeline containing a module group pass cannot be run on a single
  // module; Run() must fail with an explanatory error instead.
  const std::string module_str = R"(
HloModule ModuleGroupPassOnModule

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT foo = f32[] multiply(a, b)
}
)";
  HloPassPipeline pipeline(TestName());
  pipeline.AddPass<BazToQuxModuleGroupPass>();

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
                          ParseAndReturnVerifiedModule(module_str));
  const Status status = pipeline.Run(module.get()).status();
  ASSERT_IS_NOT_OK(status);
  EXPECT_THAT(
      status.error_message(),
      ::testing::HasSubstr("Module group pass cannot be run on a module"));
}
365
366 // Test that metadata is set when a module group goes through a pass pipeline.
TEST_F(HloPassPipelineTest,SetHloModuleMetadata)367 TEST_F(HloPassPipelineTest, SetHloModuleMetadata) {
368 HloModuleGroup module_group(TestName());
369 module_group.push_back(CreateNewVerifiedModule());
370 module_group.push_back(CreateNewVerifiedModule());
371
372 HloPassPipeline pipeline(TestName());
373 pipeline.AddPass<BazToQuxModuleGroupPass>();
374 pipeline.AddPass<FooToBarModulePass>();
375 TF_ASSERT_OK(pipeline.RunOnModuleGroup(&module_group).status());
376 ASSERT_THAT(module_group.modules(), SizeIs(2));
377
378 std::vector<std::string> pass_names = {"pipeline-start", "baz2qux",
379 "foo2bar"};
380 std::string pipeline_name = std::string(pipeline.name());
381 for (const HloModule* module : module_group.modules()) {
382 const HloModuleMetadataProto& metadata = module->metadata().proto();
383 EXPECT_EQ(metadata.canonical_module_id(), module->unique_id());
384 EXPECT_EQ(metadata.module_group_name(), module_group.name());
385
386 ASSERT_THAT(metadata.pass_metadata(), SizeIs(3));
387 for (int pass = 0; pass < metadata.pass_metadata().size(); pass++) {
388 const HloPassMetadata& pass_metadata = metadata.pass_metadata(pass);
389 EXPECT_NE(pass_metadata.pass_id(), 0);
390 EXPECT_THAT(pass_metadata.pass_name(), StrEq(pass_names[pass]));
391 EXPECT_THAT(pass_metadata.pipeline_name(), StrEq(pipeline_name));
392 EXPECT_FALSE(pass_metadata.module_changed());
393 EXPECT_EQ(pass_metadata.module_id(), module->unique_id());
394 EXPECT_THAT(pass_metadata.module_group_module_ids(),
395 ElementsAre(module_group.module(0).unique_id(),
396 module_group.module(1).unique_id()));
397 EXPECT_GT(pass_metadata.start_timestamp_usec(), 0);
398 EXPECT_LE(pass_metadata.start_timestamp_usec(),
399 pass_metadata.end_timestamp_usec());
400 }
401 }
402 }
403
404 } // namespace
405 } // namespace xla
406