1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17
18 #include <string>
19
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/compiler/xla/service/async_op_canonicalizer.h"
22 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_dce.h"
26 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
27 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
30 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/test_helpers.h"
35 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37 #include "tensorflow/core/lib/core/status_test_util.h"
38 #include "tensorflow/core/platform/logging.h"
39 #include "tensorflow/core/platform/test.h"
40
41 namespace xla {
42 namespace {
43
44 using ::testing::ElementsAre;
45 using ::testing::IsEmpty;
46 using ::testing::UnorderedElementsAre;
47
48 // Test is parameterized on a bool which is whether the dataflow analysis is
49 // performed with SSA form.
50 class HloDataflowAnalysisTest : public HloTestBase,
51 public ::testing::WithParamInterface<bool> {
52 protected:
HloDataflowAnalysisTest()53 HloDataflowAnalysisTest() : module_(CreateNewVerifiedModule()) {}
54
55 // Run dataflow analysis on the member module. For convenience returns a
56 // reference to the generated analysis stored in analysis_.
RunAnalysis(bool ssa_form,bool bitcast_defines_value=false,bool run_dce=true)57 const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
58 bool bitcast_defines_value = false,
59 bool run_dce = true) {
60 AsyncOpCanonicalizer async_op_canonicalizer;
61 EXPECT_TRUE(async_op_canonicalizer.Run(module_.get()).ok());
62 if (run_dce) {
63 HloDCE dce;
64 EXPECT_TRUE(dce.Run(module_.get()).ok());
65 }
66 FlattenCallGraph flatten;
67 EXPECT_TRUE(flatten.Run(module_.get()).ok());
68 analysis_ =
69 HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value)
70 .value();
71 return *analysis_;
72 }
73
74 // Return a vector of the HloValues at the given program position.
HloValuesAt(const HloInstruction * instruction,const ShapeIndex & index={})75 const std::vector<const HloValue*>& HloValuesAt(
76 const HloInstruction* instruction, const ShapeIndex& index = {}) {
77 CHECK(analysis_ != nullptr);
78 return analysis_->GetValueSet(instruction, index).values();
79 }
80
81 // Returns true if the top-level values for instructions 'a' and 'b' may
82 // interfere. Precondition: 'a' and 'b' define array-shaped values.
InstructionsMayInterfere(const HloOrdering & ordering,const HloInstruction * a,const HloInstruction * b)83 bool InstructionsMayInterfere(const HloOrdering& ordering,
84 const HloInstruction* a,
85 const HloInstruction* b) {
86 EXPECT_FALSE(a->shape().IsTuple());
87 EXPECT_FALSE(b->shape().IsTuple());
88 return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
89 analysis_->GetValueDefinedAt(b), *analysis_);
90 }
91
CreateR0F32UnaryOpComputation(HloOpcode opcode)92 std::unique_ptr<HloComputation> CreateR0F32UnaryOpComputation(
93 HloOpcode opcode) {
94 HloComputation::Builder builder(TestName() + "." + HloOpcodeString(opcode));
95 HloInstruction* param0 = builder.AddInstruction(
96 HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
97 builder.AddInstruction(
98 HloInstruction::CreateUnary(scalar_shape_, opcode, param0));
99 return builder.Build();
100 }
101
102 std::unique_ptr<HloModule> module_;
103 std::unique_ptr<HloDataflowAnalysis> analysis_;
104
105 const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
106 const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42});
107 const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
108 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
109 };
110
TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
  // Dataflow for a simple binary operation: an Add of two constants.
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, lhs, rhs));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  // One value per instruction.
  EXPECT_EQ(analysis.values().size(), 3);
  EXPECT_TRUE(analysis.ValueIsDefinedAt(lhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(rhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(sum));

  const HloValue& lhs_value = analysis.GetValueDefinedAt(lhs);
  const HloValue& rhs_value = analysis.GetValueDefinedAt(rhs);
  const HloValue& sum_value = analysis.GetValueDefinedAt(sum);

  // No instruction forwards a value, so every value appears only at its
  // defining position.
  EXPECT_THAT(lhs_value.positions(),
              UnorderedElementsAre(HloPosition{lhs, {}}));
  EXPECT_THAT(rhs_value.positions(),
              UnorderedElementsAre(HloPosition{rhs, {}}));
  EXPECT_THAT(sum_value.positions(),
              UnorderedElementsAre(HloPosition{sum, {}}));

  // Each constant is used once, as the matching operand of the add; the add
  // itself has no uses.
  EXPECT_THAT(lhs_value.GetUses(), UnorderedElementsAre(HloUse{sum, 0, {}}));
  EXPECT_THAT(rhs_value.GetUses(), UnorderedElementsAre(HloUse{sum, 1, {}}));
  EXPECT_TRUE(sum_value.GetUses().empty());

  // Only the add is live out of the module.
  EXPECT_FALSE(lhs_value.live_out_of_module());
  EXPECT_FALSE(rhs_value.live_out_of_module());
  EXPECT_TRUE(sum_value.live_out_of_module());
}
153
TEST_P(HloDataflowAnalysisTest, TupleAndGtes) {
  // Dataflow through a Tuple instruction and the GetTupleElement instructions
  // that read its elements back.
  HloComputation::Builder builder(TestName());
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* pair =
      builder.AddInstruction(HloInstruction::CreateTuple({p0, p1}));
  HloInstruction* get0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, pair, 0));
  HloInstruction* get1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, pair, 1));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, get0, get1));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  // Values are defined by the two parameters, the tuple (top level only) and
  // the add; the GTEs merely forward the parameter values.
  EXPECT_EQ(analysis.values().size(), 4);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(p0));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(p1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pair, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair, /*index=*/{1}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(get0));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(get1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(sum));

  const HloValue& p0_value = analysis.GetValueDefinedAt(p0);
  const HloValue& p1_value = analysis.GetValueDefinedAt(p1);
  const HloValue& pair_value = analysis.GetValueDefinedAt(pair, /*index=*/{});

  // Positions: each parameter value also appears inside the tuple and at the
  // GTE which extracts it.
  EXPECT_THAT(p0_value.positions(),
              UnorderedElementsAre(HloPosition{p0, {}}, HloPosition{pair, {0}},
                                   HloPosition{get0, {}}));
  EXPECT_THAT(p1_value.positions(),
              UnorderedElementsAre(HloPosition{p1, {}}, HloPosition{pair, {1}},
                                   HloPosition{get1, {}}));
  EXPECT_THAT(pair_value.positions(),
              UnorderedElementsAre(HloPosition{pair, {}}));

  // Uses: note that a GetTupleElement instruction is a use of only the
  // top-level value of its tuple operand.
  EXPECT_THAT(p0_value.GetUses(), UnorderedElementsAre(HloUse{sum, 0, {}}));
  EXPECT_THAT(p1_value.GetUses(), UnorderedElementsAre(HloUse{sum, 1, {}}));
  EXPECT_THAT(pair_value.GetUses(),
              UnorderedElementsAre(HloUse{get0, 0, {}}, HloUse{get1, 0, {}}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(sum).live_out_of_module());
}
209
TEST_P(HloDataflowAnalysisTest, NestedTuple) {
  // Dataflow through a tuple nested inside another tuple.
  HloComputation::Builder builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner =
      builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* outer =
      builder.AddInstruction(HloInstruction::CreateTuple({inner, inner, c1}));
  HloInstruction* gte_inner = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(inner->shape(), outer, 1));
  HloInstruction* gte_scalar = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, gte_inner, 0));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  EXPECT_EQ(analysis.values().size(), 4);

  // The first constant shows up everywhere it is forwarded to.
  EXPECT_THAT(
      analysis.GetValueDefinedAt(c1).positions(),
      UnorderedElementsAre(HloPosition{c1, {}}, HloPosition{inner, {0}},
                           HloPosition{outer, {0, 0}},
                           HloPosition{outer, {1, 0}}, HloPosition{outer, {2}},
                           HloPosition{gte_inner, {0}},
                           HloPosition{gte_scalar, {}}));
  // The first constant has exactly one use, at the final GTE (the root of the
  // computation); the second constant has no uses at all.
  EXPECT_THAT(analysis.GetValueDefinedAt(c1, /*index=*/{}).GetUses(),
              UnorderedElementsAre(HloUse{gte_scalar, 0, {0}}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(c2).GetUses().empty());

  // The top-level tuple values are used by the GTE instructions.
  const HloValue& inner_value = analysis.GetValueDefinedAt(inner, /*index=*/{});
  const HloValue& outer_value = analysis.GetValueDefinedAt(outer, /*index=*/{});
  EXPECT_THAT(inner_value.GetUses(),
              UnorderedElementsAre(HloUse{gte_scalar, 0, {}}));
  EXPECT_THAT(outer_value.GetUses(),
              UnorderedElementsAre(HloUse{gte_inner, 0, {}}));

  // Only the first constant is live out of the module.
  EXPECT_TRUE(analysis.GetValueDefinedAt(c1).live_out_of_module());
  EXPECT_FALSE(analysis.GetValueDefinedAt(c2).live_out_of_module());
  EXPECT_FALSE(inner_value.live_out_of_module());
  EXPECT_FALSE(outer_value.live_out_of_module());
}
260
TEST_P(HloDataflowAnalysisTest, SingleCall) {
  // A single call of a subcomputation which adds its two array-shaped
  // parameters.
  HloComputation::Builder callee_builder("Subcomputation");
  HloInstruction* callee_param0 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* callee_param1 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* callee_add =
      callee_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, callee_param0, callee_param1));
  HloComputation* callee =
      module_->AddEmbeddedComputation(callee_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* arg0 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* arg1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* call_site = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {arg0, arg1}, callee));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  EXPECT_EQ(analysis.values().size(), 3);

  // Neither the callee's parameters nor the call instruction define values;
  // their values flow in from elsewhere.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(arg0));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(arg1));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_param0));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_param1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(callee_add));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(call_site));

  const HloValue& arg0_value = analysis.GetValueDefinedAt(arg0);
  const HloValue& arg1_value = analysis.GetValueDefinedAt(arg1);
  const HloValue& add_value = analysis.GetValueDefinedAt(callee_add);

  // The call operands reach the callee parameters, and the callee's add
  // reaches the call instruction.
  EXPECT_EQ(analysis.GetUniqueValueAt(callee_param0), arg0_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(callee_param1), arg1_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(call_site), add_value);

  EXPECT_THAT(arg0_value.GetUses(),
              UnorderedElementsAre(HloUse{call_site, 0, {}},
                                   HloUse{callee_add, 0, {}}));
  EXPECT_THAT(arg1_value.GetUses(),
              UnorderedElementsAre(HloUse{call_site, 1, {}},
                                   HloUse{callee_add, 1, {}}));

  EXPECT_TRUE(add_value.live_out_of_module());
}
311
TEST_P(HloDataflowAnalysisTest, NestedCalls) {
  // A module with nested computations. HLO is:
  //
  // F32[] inner_computation(F32[] %param0, F32[] %param1):
  //   %add = Add(%param0, %param1)
  //
  // F32[] outer_computation(F32[] %param0, F32[] %param1):
  //   ;; Note that the parameters are interchanged in the call.
  //   %nested_call = Call(inner_computation, {%param1, %param0})
  //
  // F32[] entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %call = Call(outer_computation, {%constant1, %constant2})
  //
  HloComputation::Builder inner_builder("InnerComputation");
  HloInstruction* inner_p0 = inner_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* inner_p1 = inner_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* inner_add =
      inner_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, inner_p0, inner_p1));
  HloComputation* inner_computation =
      module_->AddEmbeddedComputation(inner_builder.Build());

  HloComputation::Builder outer_builder("OuterComputation");
  HloInstruction* outer_p0 = outer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* outer_p1 = outer_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  // The operands are deliberately passed in swapped order.
  HloInstruction* inner_call =
      outer_builder.AddInstruction(HloInstruction::CreateCall(
          scalar_shape_, {outer_p1, outer_p0}, inner_computation));
  HloComputation* outer_computation =
      module_->AddEmbeddedComputation(outer_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* outer_call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {c1, c2}, outer_computation));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  // Only three values are defined; most instructions simply pass their operand
  // values through.
  EXPECT_EQ(analysis.values().size(), 3);

  // The uses of the constants must reflect the operand permutation performed
  // by the nested call.
  EXPECT_THAT(analysis.GetValueDefinedAt(c1).GetUses(),
              UnorderedElementsAre(HloUse{outer_call, 0, {}},
                                   HloUse{inner_call, 1, {}},
                                   HloUse{inner_add, 1, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(c2).GetUses(),
              UnorderedElementsAre(HloUse{outer_call, 1, {}},
                                   HloUse{inner_call, 0, {}},
                                   HloUse{inner_add, 0, {}}));

  EXPECT_TRUE(analysis.GetValueDefinedAt(inner_add).live_out_of_module());
}
378
TEST_P(HloDataflowAnalysisTest, SingleWhile) {
  // A single while instruction whose body passes one element through
  // unchanged. HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   %add = Add(%tuple_param{0}, %tuple_param{1})
  //   return Tuple(%tuple_param{0}, %add)
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple = Tuple(%constant1, %constant2)
  //   return While(%tuple, body, condition)
  //
  const Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body: element 0 is forwarded as-is, element 1 becomes the sum of both.
  HloComputation::Builder body_builder("body");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* body_elem0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
  HloInstruction* body_elem1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
  HloInstruction* body_add =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, body_elem0, body_elem1));
  HloInstruction* body_root = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_elem0, body_add}));
  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  // Condition: trivially returns the constant "false".
  HloComputation::Builder cond_builder("condition");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* cond_false = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* loop = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape, condition, body, init));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_FALSE(analysis.GetValueDefinedAt(cond_false).live_out_of_module());

  if (!ssa_form) {
    // In non-SSA form neither the while instruction nor the subcomputation
    // parameters define values.
    EXPECT_FALSE(analysis.ValueIsDefinedAt(loop, /*index=*/{0}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(loop, /*index=*/{1}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));

    EXPECT_TRUE(analysis.GetValueDefinedAt(c1).live_out_of_module());
    EXPECT_TRUE(analysis.GetValueDefinedAt(body_add).live_out_of_module());
    return;
  }

  // SSA form. Element 0 passes through the body, so no phi value is defined
  // for it.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(loop, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));

  // Element 1 is rewritten by the body, so it must be a phi value at the
  // while instruction and at both subcomputation parameters.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(loop, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(loop, /*index=*/{1}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());

  EXPECT_THAT(analysis.GetValueDefinedAt(c1).GetUses(),
              UnorderedElementsAre(HloUse{body_add, 0, {}},
                                   HloUse{body_root, 0, {}},
                                   HloUse{loop, 0, {0}}));

  // The first constant passes through the body and out of the module.
  EXPECT_TRUE(analysis.GetValueDefinedAt(c1).live_out_of_module());
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(loop, /*index=*/{1}).live_out_of_module());

  EXPECT_FALSE(analysis.GetValueDefinedAt(body_add).live_out_of_module());
}
479
TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
  // Three while instructions chained one after another, all sharing a body
  // which passes one element through unchanged. HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   %add = Add(%tuple_param{0}, %tuple_param{1})
  //   return Tuple(%tuple_param{0}, %add)
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple = Tuple(%constant1, %constant2)
  //   %while0 = While(%tuple, body, condition)
  //   %while1 = While(%while0, body, condition)
  //   return While(%while1, body, condition)
  //
  const Shape loop_state_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body: element 0 is forwarded as-is.
  HloComputation::Builder body_builder("body");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* body_elem0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
  HloInstruction* body_elem1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
  HloInstruction* body_add =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, body_elem0, body_elem1));
  body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_elem0, body_add}));
  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  HloComputation::Builder cond_builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* loop0 = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape, condition, body, init));
  HloInstruction* loop1 = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape, condition, body, loop0));
  HloInstruction* loop2 = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(loop_state_shape, condition, body, loop1));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  // Element 0 is passed through every while instruction and out of the
  // module.
  EXPECT_EQ(analysis.GetUniqueValueAt(loop0, /*index=*/{0}),
            analysis.GetValueDefinedAt(c1));
  EXPECT_EQ(analysis.GetUniqueValueAt(loop1, /*index=*/{0}),
            analysis.GetValueDefinedAt(c1));
  EXPECT_EQ(analysis.GetUniqueValueAt(loop2, /*index=*/{0}),
            analysis.GetValueDefinedAt(c1));
  EXPECT_TRUE(analysis.GetValueDefinedAt(c1).live_out_of_module());
}
553
TEST_P(HloDataflowAnalysisTest, MultiLevelNestedWhile) {
  // Three levels of nested while instructions over a one-element tuple. The
  // level0 body (innermost while) and the level1 body pass their parameter
  // through, while the level2 body (outermost while) modifies it.
  //
  // level0_body((F32[]) %tuple_param):
  //   return Tuple(%tuple_param{0})
  //
  // level1_body((F32[]) %tuple_param):
  //   return While(%tuple_param), body=level0
  //
  // level2_body((F32[]) %tuple_param):
  //   %while = While(%tuple_param), body=level1
  //   return Tuple(Negate(%while{0}))
  //
  // entry:
  //   %constant = Constant(1.0)
  //   %tuple = Tuple(%constant)
  //   return While(%tuple), body=level2
  //
  const Shape loop_state_shape = ShapeUtil::MakeTupleShape({scalar_shape_});

  // All three whiles share one condition which returns the constant "false".
  HloComputation::Builder cond_builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Level 0: element 0 passes transparently through the body.
  HloComputation::Builder level0_builder("level0_body");
  HloInstruction* level0_param = level0_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* level0_elem0 = level0_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, level0_param, 0));
  HloInstruction* level0_root = level0_builder.AddInstruction(
      HloInstruction::CreateTuple({level0_elem0}));
  HloComputation* level0_body =
      module_->AddEmbeddedComputation(level0_builder.Build());

  // Level 1: the body is just a while over the level0 body.
  HloComputation::Builder level1_builder("level1_body");
  HloInstruction* level1_param = level1_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* level1_root =
      level1_builder.AddInstruction(HloInstruction::CreateWhile(
          loop_state_shape, condition, level0_body, level1_param));
  HloComputation* level1_body =
      module_->AddEmbeddedComputation(level1_builder.Build());

  // Level 2: runs a while over the level1 body and negates element 0 of its
  // result.
  HloComputation::Builder level2_builder("level2_body");
  HloInstruction* level2_param = level2_builder.AddInstruction(
      HloInstruction::CreateParameter(0, loop_state_shape, "param"));
  HloInstruction* level2_while =
      level2_builder.AddInstruction(HloInstruction::CreateWhile(
          loop_state_shape, condition, level1_body, level2_param));
  HloInstruction* level2_elem0 = level2_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, level2_while, 0));
  HloInstruction* negated =
      level2_builder.AddInstruction(HloInstruction::CreateUnary(
          scalar_shape_, HloOpcode::kNegate, level2_elem0));
  level2_builder.AddInstruction(HloInstruction::CreateTuple({negated}));
  HloComputation* level2_body =
      module_->AddEmbeddedComputation(level2_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* init =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({c1}));
  entry_builder.AddInstruction(HloInstruction::CreateWhile(
      loop_state_shape, condition, level2_body, init));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  // The expectations below concern phi values, so they are only checked in
  // SSA form.
  const bool ssa_form = GetParam();
  if (!ssa_form) {
    return;
  }
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Phi nodes on the inner parameters and roots should have been eliminated.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(level1_param, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(level0_param, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(level1_root, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(level0_root, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(level2_param, /*index=*/{0}));

  // Every inner position carries exactly the values seen at the outermost
  // body's parameter.
  EXPECT_EQ(HloValuesAt(level1_param, /*index=*/{0}),
            HloValuesAt(level2_param, /*index=*/{0}));
  EXPECT_EQ(HloValuesAt(level0_param, /*index=*/{0}),
            HloValuesAt(level2_param, /*index=*/{0}));
  EXPECT_EQ(HloValuesAt(level1_root, /*index=*/{0}),
            HloValuesAt(level2_param, /*index=*/{0}));
  EXPECT_EQ(HloValuesAt(level0_root, /*index=*/{0}),
            HloValuesAt(level2_param, /*index=*/{0}));
}
647
TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
  // Test nested while instructions. The inner body passes through element 0 of
  // its parameter, and the outer body passes through element 1. HLO:
  //
  // inner_body((F32[], F32[]) %tuple_param):
  //   %add = Add(%tuple_param{0}, %tuple_param{1})
  //   return Tuple(%tuple_param{0}, %add)
  //
  // outer_body((F32[], F32[]) %tuple_param):
  //   %negate = Negate(%tuple_param{0})
  //   %tuple = Tuple(%negate, %tuple_param{1})
  //   return While(%tuple, inner_body, condition)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple = Tuple(%constant1, %constant2)
  //   return While(%tuple, outer_body, condition)
  //
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // A single condition computation (constant false) is shared by both whiles.
  auto cond_builder = HloComputation::Builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Element 0 passes transparently through the body.
  auto inner_builder = HloComputation::Builder("inner_body");
  auto inner_param = inner_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  auto inner_element_0 = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 0));
  auto inner_element_1 = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1));
  auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, inner_element_0, inner_element_1));
  inner_builder.AddInstruction(
      HloInstruction::CreateTuple({inner_element_0, add}));
  HloComputation* inner_body =
      module_->AddEmbeddedComputation(inner_builder.Build());

  // Element 1 passes transparently through the body.
  auto outer_builder = HloComputation::Builder("outer_body");
  auto outer_param = outer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  auto outer_element_0 = outer_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 0));
  auto negate = outer_builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape_, HloOpcode::kNegate, outer_element_0));
  auto outer_element_1 = outer_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 1));
  auto outer_tuple = outer_builder.AddInstruction(
      HloInstruction::CreateTuple({negate, outer_element_1}));
  auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, condition, inner_body, outer_tuple));
  HloComputation* outer_body =
      module_->AddEmbeddedComputation(outer_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  auto constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  auto tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({constant1, constant2}));
  auto entry_while = builder.AddInstruction(
      HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // In both forms, element 0 of the inner body's parameter is exactly %negate.
  EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(negate)));
  if (ssa_form) {
    // Element 1 of the inner parameter is expected to be a phi value.
    EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1}));
    EXPECT_TRUE(
        analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());

    // Element 0 of the nested while is %negate: no value is defined at the
    // while itself, and the value reaching it is the one defined by %negate.
    // (The check below previously re-tested inner_param, which left the value
    // at nested_while{0} unverified in SSA form.)
    EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
    EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{0}),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(negate)));
    // Element 1 is a phi value (join of %add and %constant2).
    EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1}));
    EXPECT_TRUE(
        analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi());

    // Both elements of the entry while are expected to be phi values.
    EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{0}));
    EXPECT_TRUE(
        analysis.GetValueDefinedAt(entry_while, /*index=*/{0}).is_phi());

    EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{1}));
    EXPECT_TRUE(
        analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi());
  } else {
    // Without SSA form, each position lists every value that may reach it.
    EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(add),
                                     &analysis.GetValueDefinedAt(constant2)));

    EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{0}),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(negate)));
    EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{1}),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(add),
                                     &analysis.GetValueDefinedAt(constant2)));

    EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{0}),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(negate),
                                     &analysis.GetValueDefinedAt(constant1)));
    EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{1}),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(add),
                                     &analysis.GetValueDefinedAt(constant2)));
  }
}
767
TEST_P(HloDataflowAnalysisTest, SwizzlingWhileSharedInput) {
  // Exercises a while instruction whose body permutes its tuple parameter
  // elements, where both elements of the initial tuple are the same constant.
  // HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   return Tuple(%tuple_param{1}, %tuple_param{0})
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %tuple = Tuple(%constant1, %constant1)
  //   return While(%tuple, body, condition)
  //
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Loop body: returns the parameter's two elements in swapped order.
  HloComputation::Builder swap_builder("body");
  HloInstruction* loop_param = swap_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "param"));
  HloInstruction* first_elem = swap_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, loop_param, 0));
  HloInstruction* second_elem = swap_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, loop_param, 1));
  swap_builder.AddInstruction(
      HloInstruction::CreateTuple({second_elem, first_elem}));
  HloComputation* swap_body =
      module_->AddEmbeddedComputation(swap_builder.Build());

  // Loop condition: a constant false.
  HloComputation::Builder pred_builder("condition");
  pred_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "param"));
  pred_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* loop_cond =
      module_->AddEmbeddedComputation(pred_builder.Build());

  // Entry: the while's init tuple holds the same constant in both slots.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* shared_constant = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* init_tuple = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({shared_constant, shared_constant}));
  entry_builder.AddInstruction(
      HloInstruction::CreateWhile(pair_shape, loop_cond, swap_body, init_tuple));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // In both forms the test expects no value to be defined at element 0 of the
  // body parameter.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(loop_param, /*index=*/{0}));
}
819
TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) {
  // Exercises a while instruction whose body permutes its tuple parameter
  // elements, starting from two distinct constants. HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   return Tuple(%tuple_param{1}, %tuple_param{0})
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %constant2 = Constant(2.0)
  //   %tuple = Tuple(%constant1, %constant2)
  //   return While(%tuple, body, condition)
  //
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Loop body: returns the parameter's two elements in swapped order.
  HloComputation::Builder swap_builder("body");
  HloInstruction* loop_param = swap_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "param"));
  HloInstruction* first_elem = swap_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, loop_param, 0));
  HloInstruction* second_elem = swap_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, loop_param, 1));
  swap_builder.AddInstruction(
      HloInstruction::CreateTuple({second_elem, first_elem}));
  HloComputation* swap_body =
      module_->AddEmbeddedComputation(swap_builder.Build());

  // Loop condition: a constant false.
  HloComputation::Builder pred_builder("condition");
  HloInstruction* pred_param = pred_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "param"));
  pred_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* loop_cond =
      module_->AddEmbeddedComputation(pred_builder.Build());

  // Entry: While((constant1, constant2)).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* one = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* two = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* init_tuple =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({one, two}));
  HloInstruction* loop = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(pair_shape, loop_cond, swap_body, init_tuple));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  if (!ssa_form) {
    // Elements 0 and 1 have both constants as reaching definitions.
    EXPECT_THAT(HloValuesAt(loop, /*index=*/{0}),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(one),
                                     &analysis.GetValueDefinedAt(two)));
    EXPECT_THAT(HloValuesAt(loop, /*index=*/{1}),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(one),
                                     &analysis.GetValueDefinedAt(two)));
    EXPECT_TRUE(analysis.GetValueDefinedAt(one).live_out_of_module());
    EXPECT_TRUE(analysis.GetValueDefinedAt(two).live_out_of_module());
    return;
  }

  // SSA form: elements 0 and 1 should be phi values at the body parameter...
  EXPECT_TRUE(analysis.ValueIsDefinedAt(loop_param, /*index=*/{0}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(loop_param, /*index=*/{0}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(loop_param, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(loop_param, /*index=*/{1}).is_phi());

  // ...at the while instruction itself...
  EXPECT_TRUE(analysis.ValueIsDefinedAt(loop, /*index=*/{0}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(loop, /*index=*/{0}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(loop, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(loop, /*index=*/{1}).is_phi());

  // ...and at the condition parameter.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pred_param, /*index=*/{0}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(pred_param, /*index=*/{0}).is_phi());
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pred_param, /*index=*/{1}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(pred_param, /*index=*/{1}).is_phi());

  // The values defined at the while are live out of the module; the
  // constants are not.
  EXPECT_FALSE(analysis.GetValueDefinedAt(one).live_out_of_module());
  EXPECT_FALSE(analysis.GetValueDefinedAt(two).live_out_of_module());
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(loop, /*index=*/{}).live_out_of_module());
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(loop, /*index=*/{0}).live_out_of_module());
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(loop, /*index=*/{1}).live_out_of_module());
}
910
TEST_P(HloDataflowAnalysisTest, ArraySelect) {
  // A kSelect over array (scalar-shaped) operands: the select defines its own
  // value, and only that value is live out of the module.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* predicate = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* on_true = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* on_false = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* select_op =
      entry_builder.AddInstruction(HloInstruction::CreateTernary(
          scalar_shape_, HloOpcode::kSelect, predicate, on_true, on_false));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(select_op));
  EXPECT_FALSE(analysis.GetValueDefinedAt(on_true).live_out_of_module());
  EXPECT_FALSE(analysis.GetValueDefinedAt(on_false).live_out_of_module());
  EXPECT_TRUE(analysis.GetValueDefinedAt(select_op).live_out_of_module());
}
934
TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) {
  // Checks both settings of the analysis' bitcast_defines_value flag on a
  // module consisting of a constant feeding a bitcast.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* source = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* bitcast_op = entry_builder.AddInstruction(
      HloInstruction::CreateBitcast(scalar_shape_, source));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();

  // Each run gets its own scope so the reference returned by one RunAnalysis
  // call is never used after the next call.
  {
    // Flag on: the bitcast defines a second value, and that value is the one
    // that is live out.
    const HloDataflowAnalysis& with_flag =
        RunAnalysis(ssa_form, /*bitcast_defines_value=*/true);

    EXPECT_EQ(with_flag.values().size(), 2);

    EXPECT_TRUE(with_flag.ValueIsDefinedAt(source));
    EXPECT_TRUE(with_flag.ValueIsDefinedAt(bitcast_op));
    EXPECT_FALSE(with_flag.GetValueDefinedAt(source).live_out_of_module());
    EXPECT_TRUE(with_flag.GetValueDefinedAt(bitcast_op).live_out_of_module());
  }
  {
    // Flag off: only the constant defines a value, and it is live out.
    const HloDataflowAnalysis& without_flag =
        RunAnalysis(ssa_form, /*bitcast_defines_value=*/false);
    EXPECT_EQ(without_flag.values().size(), 1);

    EXPECT_TRUE(without_flag.ValueIsDefinedAt(source));
    EXPECT_FALSE(without_flag.ValueIsDefinedAt(bitcast_op));
    EXPECT_TRUE(without_flag.GetValueDefinedAt(source).live_out_of_module());
  }
}
968
TEST_P(HloDataflowAnalysisTest, TupleCopy) {
  // A copy of a tuple-shaped operand defines only the top-level value; the
  // leaf elements still refer to the values of the original operands.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* lhs = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* rhs = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* pair =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({lhs, rhs}));
  HloInstruction* pair_copy = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(pair->shape(), HloOpcode::kCopy, pair));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Expected values: two parameters, the tuple's top level, the copy's top
  // level.
  EXPECT_EQ(analysis.values().size(), 4);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(lhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(rhs));

  // The tuple defines only its top-level index.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pair, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair, /*index=*/{1}));

  // So does the copy.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pair_copy, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair_copy, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair_copy, /*index=*/{1}));

  // The copy's elements carry the parameters' values.
  EXPECT_THAT(HloValuesAt(pair_copy, /*index=*/{0}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(lhs)));
  EXPECT_THAT(HloValuesAt(pair_copy, /*index=*/{1}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(rhs)));
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(pair_copy, /*index=*/{}).live_out_of_module());
}
1004
TEST_P(HloDataflowAnalysisTest, OptimizationBarrier) {
  // An optimization barrier is a no-op for dataflow: it defines no values of
  // its own and forwards its operand's values.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* lhs = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* rhs = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* pair =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({lhs, rhs}));
  HloInstruction* opt_barrier =
      entry_builder.AddInstruction(HloInstruction::CreateUnary(
          pair->shape(), HloOpcode::kOptimizationBarrier, pair));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Expected values: two parameters and the tuple's top level.
  EXPECT_EQ(analysis.values().size(), 3);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(lhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(rhs));

  // The tuple defines only its top-level index.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pair, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair, /*index=*/{1}));

  // The barrier defines nothing at any index.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(opt_barrier, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(opt_barrier, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(opt_barrier, /*index=*/{1}));

  // Its elements carry the parameters' values.
  EXPECT_THAT(HloValuesAt(opt_barrier, /*index=*/{0}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(lhs)));
  EXPECT_THAT(HloValuesAt(opt_barrier, /*index=*/{1}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(rhs)));
}
1038
TEST_P(HloDataflowAnalysisTest, CopyStartAndCopyDone) {
  // A CopyDone forwards element {0} of its CopyStart operand tuple to its
  // output.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* source = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));

  // CopyStart result shape: (source shape, source shape, U32 scalar).
  const Shape start_shape = ShapeUtil::MakeTupleShape(
      {source->shape(), source->shape(), ShapeUtil::MakeShape(U32, {})});
  HloInstruction* start_op = entry_builder.AddInstruction(
      HloInstruction::CreateCopyStart(start_shape, source));
  HloInstruction* done_op =
      entry_builder.AddInstruction(HloInstruction::CreateUnary(
          source->shape(), HloOpcode::kCopyDone, start_op));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_EQ(analysis.values().size(), 4);

  // CopyStart defines its top level and elements {0} and {2}, but not {1}.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(start_op, /*index=*/{}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(start_op, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(start_op, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(start_op, /*index=*/{2}));

  // CopyDone defines nothing; it carries CopyStart's {0} value, which is live
  // out of the module.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(done_op, /*index=*/{}));
  EXPECT_THAT(HloValuesAt(done_op, /*index=*/{}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(start_op, {0})));
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(start_op, /*index=*/{0}).live_out_of_module());
}
1070
TEST_P(HloDataflowAnalysisTest, AsyncOps) {
  // Checks value definitions and forwarding through an async custom-call
  // start/update/done chain parsed from HLO text.
  const std::string module_text = R"(
  HloModule module

  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo"
    async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(async-start), custom_call_target="foo"
    ROOT async-done = f32[2,3] custom-call-done(async-update), custom_call_target="foo"
  }
)";
  TF_ASSERT_OK_AND_ASSIGN(
      module_,
      ParseAndReturnVerifiedModule(module_text, GetModuleConfigForTest()));

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  const HloInstruction* entry_param =
      module_->entry_computation()->parameter_instruction(0);
  const HloInstruction* start_op =
      FindInstruction(module_.get(), "async-start");
  const HloInstruction* update_op =
      FindInstruction(module_.get(), "async-update");
  const HloInstruction* done_op = FindInstruction(module_.get(), "async-done");
  const HloInstruction* wrapped_op = start_op->async_wrapped_instruction();

  // async-start: defines its top level and element {2}; element {0, 0}
  // carries the entry parameter's value and element {1} carries the wrapped
  // instruction's value, which is live out of the module.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(start_op, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(start_op, /*index=*/{0, 0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(start_op, /*index=*/{1}));
  EXPECT_THAT(
      HloValuesAt(start_op, {1}),
      UnorderedElementsAre(&analysis.GetValueDefinedAt(wrapped_op, {})));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(start_op, /*index=*/{2}));
  EXPECT_THAT(
      HloValuesAt(start_op, /*index=*/{0, 0}),
      UnorderedElementsAre(&analysis.GetValueDefinedAt(entry_param, {})));
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(wrapped_op, {}).live_out_of_module());

  // async-update: defines only its top level; the elements are forwarded.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(update_op, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(update_op, /*index=*/{0, 0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(update_op, /*index=*/{1}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(update_op, /*index=*/{2}));
  EXPECT_THAT(
      HloValuesAt(update_op, /*index=*/{0, 0}),
      UnorderedElementsAre(&analysis.GetValueDefinedAt(entry_param, {})));
  EXPECT_THAT(
      HloValuesAt(update_op, /*index=*/{1}),
      UnorderedElementsAre(&analysis.GetValueDefinedAt(wrapped_op, {})));

  // async-done: defines nothing and carries the wrapped instruction's value.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(done_op, /*index=*/{}));
  EXPECT_THAT(
      HloValuesAt(done_op, /*index=*/{}),
      UnorderedElementsAre(&analysis.GetValueDefinedAt(wrapped_op, {})));
}
1126
TEST_P(HloDataflowAnalysisTest, AsyncCall) {
  // Checks that values flow from the entry parameters into the computation
  // called by an async call chain, and from that computation's root out
  // through async-done.
  const std::string module_text = R"(
HloModule AsyncCall

%called_computation (param_0: f32[4096], param_1: f32[4096]) -> f32[4096] {
  %param_0 = f32[4096]{0} parameter(0)
  %param_1 = f32[4096]{0} parameter(1)
  %negate_0 = f32[4096]{0} negate(f32[4096]{0} %param_0)
  %negate_1 = f32[4096]{0} negate(f32[4096]{0} %param_1)
  ROOT %result.1 = f32[4096]{0} add(f32[4096]{0} %negate_0, f32[4096]{0} %negate_1)
}

ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] {
  %a = f32[4096]{0} parameter(0)
  %b = f32[4096]{0} parameter(1)
  %async-start = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) call-start(f32[4096]{0} %a, f32[4096]{0} %b), to_apply=%called_computation
  %negate_2 = f32[4096]{0} negate(f32[4096]{0} %a)
  %async-update = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) call-update(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start), to_apply=%called_computation
  %negate_3 = f32[4096]{0} negate(f32[4096]{0} %b)
  %add_0 = f32[4096]{0} add(f32[4096]{0} %negate_2, f32[4096]{0} %negate_3)
  %async-done = f32[4096]{0} call-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-update), to_apply=%called_computation
  ROOT %add_1 = f32[4096]{0} add(f32[4096]{0} %add_0, f32[4096]{0} %async-done)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      module_,
      ParseAndReturnVerifiedModule(module_text, GetModuleConfigForTest()));

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  const HloInstruction* entry_a = FindInstruction(module_.get(), "a");
  const HloInstruction* entry_b = FindInstruction(module_.get(), "b");
  const HloInstruction* done_op = FindInstruction(module_.get(), "async-done");

  // For every op in the async chain, the called computation's parameters
  // define no values of their own and carry the entry parameters' values,
  // while async-done carries the value defined at the called computation's
  // root.
  for (std::string op_name : {"async-start", "async-update", "async-done"}) {
    const HloInstruction* chain_op = FindInstruction(module_.get(), op_name);
    const HloComputation* callee =
        chain_op->async_wrapped_instruction()->called_computations()[0];

    const HloInstruction* callee_param0 = callee->parameter_instruction(0);
    EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_param0));
    EXPECT_THAT(HloValuesAt(callee_param0),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(entry_a)));

    const HloInstruction* callee_param1 = callee->parameter_instruction(1);
    EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_param1));
    EXPECT_THAT(HloValuesAt(callee_param1),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(entry_b)));

    const HloInstruction* callee_root = callee->root_instruction();
    EXPECT_TRUE(analysis.ValueIsDefinedAt(callee_root));
    EXPECT_THAT(HloValuesAt(done_op),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(callee_root)));
  }
}
1184
TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
  // A Send forwards its data operand to element {0} of its output tuple.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* payload = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* order_token =
      entry_builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* send_op = entry_builder.AddInstruction(
      HloInstruction::CreateSend(payload, order_token, /*channel_id=*/0));
  HloInstruction* send_done_op =
      entry_builder.AddInstruction(HloInstruction::CreateSendDone(send_op));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_EQ(analysis.values().size(), 6);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(payload));

  // Send defines every index of its result except {0}, which is forwarded
  // from the payload.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send_op, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(send_op, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send_op, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send_op, /*index=*/{2}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done_op));
  EXPECT_THAT(HloValuesAt(send_op, /*index=*/{0}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(payload)));
}
1211
TEST_P(HloDataflowAnalysisTest, SetDimensionSizeForwardsValue) {
  // SetDimensionSize defines no value of its own; the value of its data
  // operand is the one that is live out of the module.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* data = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* dim_size = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(3)));
  HloInstruction* set_dim_size = entry_builder.AddInstruction(
      HloInstruction::CreateSetDimensionSize(vector_shape_, data, dim_size, 0));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Two values in total: the parameter and the size constant.
  EXPECT_EQ(analysis.values().size(), 2);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(data));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(set_dim_size));
  EXPECT_TRUE(analysis.GetValueDefinedAt(data).live_out_of_module());
}
1234
TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
  // A RecvDone forwards element {0} of its Recv operand tuple to element {0}
  // of its own output.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* order_token =
      entry_builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* recv_op = entry_builder.AddInstruction(
      HloInstruction::CreateRecv(scalar_shape_, order_token, /*channel_id=*/0));
  HloInstruction* recv_done_op =
      entry_builder.AddInstruction(HloInstruction::CreateRecvDone(recv_op));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_EQ(analysis.values().size(), 7);

  // Recv defines every index of its result.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_op, /*index=*/{}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_op, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_op, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_op, /*index=*/{2}));

  // RecvDone defines its top level and {1}; {0} is forwarded from Recv's {0},
  // which is therefore live out of the module.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done_op, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done_op, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done_op, /*index=*/{1}));
  EXPECT_THAT(HloValuesAt(recv_done_op, /*index=*/{0}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(recv_op, {0})));
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(recv_op, /*index=*/{0}).live_out_of_module());
}
1263
TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) {
  // A straight-line chain of elementwise operations; under a dependency
  // ordering no two distinct values in the chain should interfere.
  //
  //   param --> negate -> exp -> log
  //
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate_op = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, input));
  HloInstruction* exp_op = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, negate_op));
  HloInstruction* log_op = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog, exp_op));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering dep_ordering(module_.get());

  // The parameter against everything downstream of it.
  EXPECT_FALSE(InstructionsMayInterfere(dep_ordering, input, negate_op));
  EXPECT_FALSE(InstructionsMayInterfere(dep_ordering, input, exp_op));
  EXPECT_FALSE(InstructionsMayInterfere(dep_ordering, input, log_op));

  // Pairs among negate/exp/log, in both argument orders where tested.
  EXPECT_FALSE(InstructionsMayInterfere(dep_ordering, negate_op, exp_op));
  EXPECT_FALSE(InstructionsMayInterfere(dep_ordering, negate_op, log_op));
  EXPECT_FALSE(InstructionsMayInterfere(dep_ordering, exp_op, negate_op));
  EXPECT_FALSE(InstructionsMayInterfere(dep_ordering, exp_op, log_op));
  EXPECT_FALSE(InstructionsMayInterfere(dep_ordering, log_op, negate_op));
  EXPECT_FALSE(InstructionsMayInterfere(dep_ordering, log_op, exp_op));

  // A value always interferes with itself.
  EXPECT_TRUE(InstructionsMayInterfere(dep_ordering, exp_op, exp_op));
}
1299
TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) {
  // Two entry parameters, which interfere with each other, under an explicit
  // sequential schedule.
  //
  //   param0 --> negate --+
  //                       +--> add
  //   param1 --> exp -----+
  //
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* first_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param0"));
  HloInstruction* second_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, vector_shape_, "param1"));
  HloInstruction* negate_op =
      entry_builder.AddInstruction(HloInstruction::CreateUnary(
          vector_shape_, HloOpcode::kNegate, first_param));
  HloInstruction* exp_op = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, second_param));
  HloInstruction* sum_op =
      entry_builder.AddInstruction(HloInstruction::CreateBinary(
          vector_shape_, HloOpcode::kAdd, negate_op, exp_op));

  HloComputation* entry_computation =
      module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  // Schedule: param0, negate, param1, exp, add.
  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry_computation, {first_param, negate_op,
                                            second_param, exp_op, sum_op});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering seq_ordering(schedule);

  // Entry parameters interfere as if they are defined simultaneously at
  // the very beginning.
  EXPECT_TRUE(InstructionsMayInterfere(seq_ordering, first_param, second_param));
  EXPECT_FALSE(InstructionsMayInterfere(seq_ordering, first_param, negate_op));
  EXPECT_FALSE(InstructionsMayInterfere(seq_ordering, first_param, exp_op));
  EXPECT_FALSE(InstructionsMayInterfere(seq_ordering, first_param, sum_op));
  EXPECT_TRUE(InstructionsMayInterfere(seq_ordering, second_param, first_param));
  EXPECT_TRUE(InstructionsMayInterfere(seq_ordering, second_param, negate_op));
  EXPECT_FALSE(InstructionsMayInterfere(seq_ordering, second_param, exp_op));
  EXPECT_FALSE(InstructionsMayInterfere(seq_ordering, second_param, sum_op));

  // Negate and exp still interfere.
  EXPECT_TRUE(InstructionsMayInterfere(seq_ordering, negate_op, exp_op));
  EXPECT_TRUE(InstructionsMayInterfere(seq_ordering, exp_op, negate_op));

  // But {negate, add} and {exp, add} don't interfere.
  EXPECT_FALSE(InstructionsMayInterfere(seq_ordering, negate_op, sum_op));
  EXPECT_FALSE(InstructionsMayInterfere(seq_ordering, sum_op, negate_op));
  EXPECT_FALSE(InstructionsMayInterfere(seq_ordering, exp_op, sum_op));
  EXPECT_FALSE(InstructionsMayInterfere(seq_ordering, sum_op, exp_op));
}
1347
TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
  // Variant of MultipleEntryParameters_Sequential in which the parameter under
  // test belongs to a while body. The body is scheduled as:
  //
  //   %constant      = Constant(...)
  //   %exp           = Exp(%constant)
  //   %body_param    = Param(0)
  //   %add           = Add(%exp, %body_param)   ;; root of the body
  //   %dead_constant = Constant(...)
  //   %dead_negate   = Negate(%dead_constant)
  //
  // %constant and its single user %exp are scheduled ahead of %body_param,
  // yet the test expects them to interfere with it: the parameter is treated
  // as live on entry to the body. Symmetrically, the dead instructions are
  // scheduled after the root %add, which is live out of the body, so they are
  // expected to interfere with %add.
  const bool ssa_form = GetParam();

  // Loop body. The root is passed explicitly to Build() because the dead
  // instructions are added after it.
  HloComputation::Builder body_builder(TestName());
  HloInstruction* loop_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "body_param"));
  HloInstruction* one = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* exp_of_one = body_builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, one));
  HloInstruction* body_root =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, exp_of_one, loop_param));
  HloInstruction* unused_one = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* unused_negate =
      body_builder.AddInstruction(HloInstruction::CreateUnary(
          scalar_shape_, HloOpcode::kNegate, unused_one));
  HloComputation* loop_body = module_->AddEmbeddedComputation(
      body_builder.Build(/*root_instruction=*/body_root));

  // Loop condition: a parameter plus a constant `false`.
  HloComputation::Builder cond_builder("condition");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "cond_param"));
  HloInstruction* cond_false = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* loop_cond =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Entry computation: feed the entry parameter straight into the while.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* init = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param"));
  HloInstruction* while_op = entry_builder.AddInstruction(
      HloInstruction::CreateWhile(scalar_shape_, loop_cond, loop_body, init));
  HloComputation* entry_computation =
      module_->AddEntryComputation(entry_builder.Build());

  SCOPED_TRACE(module_->ToString());
  // DCE is disabled so the dead instructions in the body are still present
  // when the analysis runs.
  RunAnalysis(ssa_form, /*bitcast_defines_value=*/false,
              /*run_dce=*/false);

  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry_computation, {init, while_op});
  schedule.set_sequence(loop_cond, {cond_param, cond_false});
  // Deliberately place the constant and its user ahead of the body parameter.
  schedule.set_sequence(loop_body, {one, exp_of_one, loop_param, body_root,
                                    unused_one, unused_negate});
  TF_ASSERT_OK(schedule.Verify());

  SequentialHloOrdering ordering(schedule);

  // The body root is live out, so it interferes with everything scheduled
  // after it.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, body_root, unused_one));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, body_root, unused_negate));

  // Everything below concerns phi values defined at the body and condition
  // parameters; those exist only in the SSA form of the analysis.
  if (!ssa_form) {
    return;
  }

  // The body parameter is live in, so it interferes with instructions that
  // the schedule places before it.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, loop_param, one));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, loop_param, exp_of_one));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, loop_param, body_root));

  // These five values are expected to share a buffer, so no pair of them may
  // interfere:
  //   (1) the init value            'param'
  //   (2) the body parameter        'body_param'
  //   (3) the condition parameter   'cond_param'
  //   (4) the body root             'add'
  //   (5) the while itself
  EXPECT_FALSE(InstructionsMayInterfere(ordering, init, loop_param));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, init, cond_param));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, init, body_root));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, init, while_op));

  EXPECT_FALSE(InstructionsMayInterfere(ordering, loop_param, cond_param));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, loop_param, body_root));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, loop_param, while_op));

  EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, body_root));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, while_op));

  EXPECT_FALSE(InstructionsMayInterfere(ordering, body_root, while_op));
}
1450
TEST_P(HloDataflowAnalysisTest, NonElementwiseOperand) {
  // Linear chain containing two elementwise ops followed by one
  // non-elementwise op:
  //
  //   param --> exp --> negate --> reverse
  //
  // The checks expect an elementwise op not to interfere with its operand and
  // the non-elementwise op (reverse) to interfere with its operand. The entry
  // parameter is expected not to interfere with any instruction in the chain.
  HloComputation::Builder chain_builder(TestName());
  HloInstruction* input = chain_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* exp_op = chain_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, input));
  HloInstruction* negate_op = chain_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, exp_op));
  HloInstruction* reverse_op = chain_builder.AddInstruction(
      HloInstruction::CreateReverse(vector_shape_, negate_op, {0}));

  module_->AddEntryComputation(chain_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // Parameter versus every downstream instruction.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, input, exp_op));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, input, negate_op));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, input, reverse_op));

  // Elementwise negate may share with its operand; non-elementwise reverse
  // may not.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp_op, negate_op));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate_op, reverse_op));
}
1483
TEST_P(HloDataflowAnalysisTest, OverlappedValues) {
  // Two values that are live at the same time (negate and exp) must
  // interfere.
  //
  //   param --> negate --> add
  //      \----> exp ------/
  //
  HloComputation::Builder diamond_builder(TestName());
  HloInstruction* input = diamond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate_op = diamond_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, input));
  HloInstruction* exp_op = diamond_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, input));
  HloInstruction* sum = diamond_builder.AddInstruction(
      HloInstruction::CreateBinary(vector_shape_, HloOpcode::kAdd, negate_op,
                                   exp_op));

  module_->AddEntryComputation(diamond_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  // Only data dependencies constrain the order here, so negate and exp are
  // unordered relative to each other.
  DependencyHloOrdering ordering(module_.get());

  // The parameter feeds both branches and interferes with each of them, but
  // not with the final add.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, input, negate_op));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, input, exp_op));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, input, sum));

  // The two branches interfere with one another (checked in both argument
  // orders); neither interferes with the add that consumes them.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate_op, exp_op));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, exp_op, negate_op));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate_op, sum));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, sum, negate_op));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp_op, sum));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, sum, exp_op));
}
1518
TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) {
  // Same graph as the OverlappedValues test, but checked against an explicit
  // sequential schedule instead of a dependency-based ordering.
  //
  //   param --> negate --> add
  //      \----> exp ------/
  //
  // Schedule: param, negate, exp, add
  HloComputation::Builder diamond_builder(TestName());
  HloInstruction* input = diamond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate_op = diamond_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, input));
  HloInstruction* exp_op = diamond_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, input));
  HloInstruction* sum = diamond_builder.AddInstruction(
      HloInstruction::CreateBinary(vector_shape_, HloOpcode::kAdd, negate_op,
                                   exp_op));

  HloComputation* entry_computation =
      module_->AddEntryComputation(diamond_builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry_computation, {input, negate_op, exp_op, sum});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering ordering(schedule);

  // With this total order the parameter is expected to interfere with negate
  // but not with exp (presumably because exp is the parameter's last use in
  // the schedule) and not with add.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, input, negate_op));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, input, exp_op));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, input, sum));

  // negate and exp overlap and therefore interfere (checked in both argument
  // orders); neither interferes with the add that consumes them.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate_op, exp_op));
  EXPECT_TRUE(InstructionsMayInterfere(ordering, exp_op, negate_op));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, negate_op, sum));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, sum, negate_op));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp_op, sum));
  EXPECT_FALSE(InstructionsMayInterfere(ordering, sum, exp_op));
}
1561
TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
  // Checks interference between values that live in different computations.
  //
  // embedded_computation:
  //   %embedded_param = Param(0)
  //   %embedded_log   = Log(%embedded_param)
  //
  // entry computation:
  //   %param  = Param(0)
  //   %negate = Negate(%param)
  //   %exp    = Exp(%param)
  //   %call   = Call(embedded_computation, {%exp})
  //   %add    = Add(%negate, %call)
  //
  // %negate is consumed only by %add, i.e. after the call, so it stays live
  // across the call and is expected to interfere with values inside the
  // callee.
  HloComputation::Builder callee_builder(TestName() + "_embedded");
  HloInstruction* callee_param = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "embedded_param"));
  HloInstruction* callee_log = callee_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog,
                                  callee_param));
  HloComputation* callee =
      module_->AddEmbeddedComputation(callee_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* input = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate_op = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, input));
  HloInstruction* exp_op = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, input));
  HloInstruction* call_op = entry_builder.AddInstruction(
      HloInstruction::CreateCall(vector_shape_, {exp_op}, callee));
  // The add becomes the entry root; its handle is not needed afterwards.
  entry_builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate_op, call_op));
  module_->AddEntryComputation(entry_builder.Build());

  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // exp's only user is the call itself, so it is not expected to interfere
  // with values defined inside the callee.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp_op, callee_log));

  // negate is live across the call, so it is expected to interfere with
  // values defined inside the callee.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate_op, callee_log));
}
1613
TEST_P(HloDataflowAnalysisTest, GetFlattenedValueSet) {
  // The root tuple `t` has two leaf elements, each a distinct copy of the
  // parameter. The test expects its flattened value set to hold three values.
  const char* hlo_text = R"(
  HloModule test_aliasing_module

  ENTRY root {
    param = s32[1000] parameter(0)
    p0 = s32[1000] copy(param)
    p1 = s32[1000] copy(param)
    ROOT t = (s32[1000], s32[1000]) tuple(p0, p1)
  })";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
  HloComputation* entry = module_->entry_computation();

  const HloDataflowAnalysis& dataflow_analysis = RunAnalysis(GetParam());

  // Look the tuple up once, after the analysis has run (the same point at
  // which the instruction was previously fetched for the query), and fail
  // cleanly if it is missing instead of handing a null pointer to the
  // analysis.
  HloInstruction* tuple = entry->GetInstructionWithName("t");
  ASSERT_NE(tuple, nullptr);

  auto set = dataflow_analysis.GetFlattenedValueSet(tuple);
  EXPECT_EQ(set.values().size(), 3);
}
1632
TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) {
  // Conditional whose two branches simply return their parameter.
  //
  // true_computation(F32[] %true_param):
  //   return %true_param
  //
  // false_computation(F32[] %false_param):
  //   return %false_param
  //
  // entry:
  //   %pred = Constant(true)
  //   %constant1 = Constant(56.0)
  //   %constant2 = Constant(12.0)
  //   return Conditional(%pred, %constant1, true_computation,
  //                      %constant2, false_computation)
  const bool ssa_form = GetParam();

  // Identity branch taken when the predicate is true.
  HloComputation::Builder true_builder(TestName() + "_true");
  HloInstruction* true_param = true_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "true_param"));
  HloComputation* true_branch =
      module_->AddEmbeddedComputation(true_builder.Build());

  // Identity branch taken when the predicate is false.
  HloComputation::Builder false_builder(TestName() + "_false");
  HloInstruction* false_param = false_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "false_param"));
  HloComputation* false_branch =
      module_->AddEmbeddedComputation(false_builder.Build());

  // Entry computation.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* predicate = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* true_input = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
  HloInstruction* false_input = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
  HloInstruction* cond_op =
      entry_builder.AddInstruction(HloInstruction::CreateConditional(
          scalar_shape_, predicate, true_input, true_branch, false_input,
          false_branch));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // The three constants each define a value.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(predicate));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(true_input));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(false_input));

  // Branch parameters define nothing; they carry the matching operand's value.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));

  EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
            analysis.GetValueDefinedAt(true_input));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
            analysis.GetValueDefinedAt(false_input));

  // Each constant is used exactly once, at its conditional operand slot.
  EXPECT_THAT(analysis.GetValueDefinedAt(predicate).GetUses(),
              ElementsAre(HloUse{cond_op, 0, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(true_input).GetUses(),
              ElementsAre(HloUse{cond_op, 1, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(false_input).GetUses(),
              ElementsAre(HloUse{cond_op, 2, {}}));

  if (!ssa_form) {
    // Non-SSA: the conditional defines no value of its own and may hold
    // either branch's result.
    EXPECT_EQ(analysis.values().size(), 3);
    EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_op));
    EXPECT_THAT(HloValuesAt(cond_op),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(true_input),
                                     &analysis.GetValueDefinedAt(false_input)));
  } else {
    // SSA: the conditional defines one additional value.
    EXPECT_EQ(analysis.values().size(), 4);
    EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_op));
  }
}
1707
TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) {
  // Conditional whose two branches both receive the same tuple operand.
  //
  // true_computation((F32[], F32[]) %true_param):
  //   %true_x = GetTupleElement(%true_param, 0)
  //   %true_y = GetTupleElement(%true_param, 1)
  //   return Add(%true_x, %true_y)
  //
  // false_computation((F32[], F32[]) %false_param):
  //   %false_x = GetTupleElement(%false_param, 0)
  //   %false_y = GetTupleElement(%false_param, 1)
  //   return Subtract(%false_x, %false_y)
  //
  // entry:
  //   %pred = Constant(true)
  //   %constant1 = Constant(56.0)
  //   %constant2 = Constant(12.0)
  //   %tuple_operand = Tuple(%constant1, %constant2)
  //   return Conditional(%pred, %tuple_operand, true_computation,
  //                      %tuple_operand, false_computation)
  const bool ssa_form = GetParam();

  // True branch: adds the two tuple elements.
  HloComputation::Builder true_builder(TestName() + "_true");
  HloInstruction* true_param = true_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape_, "true_param"));
  HloInstruction* true_x = true_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0));
  HloInstruction* true_y = true_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1));
  HloInstruction* sum = true_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, true_x,
                                   true_y));
  HloComputation* true_branch =
      module_->AddEmbeddedComputation(true_builder.Build());

  // False branch: subtracts the second tuple element from the first.
  HloComputation::Builder false_builder(TestName() + "_false");
  HloInstruction* false_param = false_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape_, "false_param"));
  HloInstruction* false_x = false_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0));
  HloInstruction* false_y = false_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1));
  HloInstruction* difference = false_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kSubtract,
                                   false_x, false_y));
  HloComputation* false_branch =
      module_->AddEmbeddedComputation(false_builder.Build());

  // Entry computation.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* predicate = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* first = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
  HloInstruction* second = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
  HloInstruction* operand_tuple = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({first, second}));
  HloInstruction* cond_op =
      entry_builder.AddInstruction(HloInstruction::CreateConditional(
          scalar_shape_, predicate, operand_tuple, true_branch, operand_tuple,
          false_branch));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Values are defined by the constants, the tuple, and each branch's
  // arithmetic root.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(predicate));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(first));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(second));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(operand_tuple));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(sum));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(difference));

  // Branch parameters and the get-tuple-elements only forward values.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y));

  // What each forwarding instruction carries.
  EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
            analysis.GetValueDefinedAt(operand_tuple));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
            analysis.GetValueDefinedAt(operand_tuple));
  EXPECT_EQ(analysis.GetUniqueValueAt(true_x),
            analysis.GetValueDefinedAt(first));
  EXPECT_EQ(analysis.GetUniqueValueAt(true_y),
            analysis.GetValueDefinedAt(second));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_x),
            analysis.GetValueDefinedAt(first));
  EXPECT_EQ(analysis.GetUniqueValueAt(false_y),
            analysis.GetValueDefinedAt(second));

  // Uses: the tuple elements are used through both conditional operand slots
  // and by the arithmetic inside each branch; the tuple itself is used by the
  // conditional and by every get-tuple-element.
  EXPECT_THAT(analysis.GetValueDefinedAt(predicate).GetUses(),
              ElementsAre(HloUse{cond_op, 0, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(first).GetUses(),
              UnorderedElementsAre(HloUse{cond_op, 1, {0}},
                                   HloUse{cond_op, 2, {0}},
                                   HloUse{sum, 0, {}},
                                   HloUse{difference, 0, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(second).GetUses(),
              UnorderedElementsAre(HloUse{cond_op, 1, {1}},
                                   HloUse{cond_op, 2, {1}},
                                   HloUse{sum, 1, {}},
                                   HloUse{difference, 1, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(operand_tuple).GetUses(),
              UnorderedElementsAre(
                  HloUse{cond_op, 1, {}}, HloUse{cond_op, 2, {}},
                  HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}},
                  HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}}));

  if (!ssa_form) {
    // Non-SSA: the conditional may hold either branch root's value.
    EXPECT_EQ(analysis.values().size(), 6);
    EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_op));
    EXPECT_THAT(HloValuesAt(cond_op),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(sum),
                                     &analysis.GetValueDefinedAt(difference)));
  } else {
    // SSA: the conditional defines one additional value.
    EXPECT_EQ(analysis.values().size(), 7);
    EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_op));
  }
}
1825
TEST_P(HloDataflowAnalysisTest, NestedConditionals) {
  // A conditional whose true branch is itself a conditional.
  //
  // computation1(F32[] %param1):
  //   %ceil = Ceil(%param1)
  //   return %ceil
  //
  // computation2(F32[] %param2):
  //   %floor = Floor(%param2)
  //   return %floor
  //
  // computation3(F32[] %param3):
  //   %negate = Negate(%param3)
  //   return %negate
  //
  // inner_conditional((PRED, F32[], F32[]) %param_cond):
  //   %pred_cond = GetTupleElement(%param_cond, 0)
  //   %true_operand_cond = GetTupleElement(%param_cond, 1)
  //   %false_operand_cond = GetTupleElement(%param_cond, 2)
  //   return Conditional(%pred_cond, %true_operand_cond, computation1,
  //                      %false_operand_cond, computation2)
  //
  // entry:
  //   %pred1 = Constant(true)
  //   %pred2 = Constant(false)
  //   %constant1 = Constant(1.1);
  //   %constant2 = Constant(2.2);
  //   %constant3 = Constant(3.3);
  //   return Conditional(%pred1, (%pred2, %constant1, %constant2),
  //                      inner_conditional, %constant3, computation3)
  const bool ssa_form = GetParam();

  // Leaf branch computations, each applying a single unary op.
  HloComputation* ceil_comp = module_->AddEmbeddedComputation(
      CreateR0F32UnaryOpComputation(HloOpcode::kCeil));
  HloComputation* floor_comp = module_->AddEmbeddedComputation(
      CreateR0F32UnaryOpComputation(HloOpcode::kFloor));
  HloComputation* negate_comp = module_->AddEmbeddedComputation(
      CreateR0F32UnaryOpComputation(HloOpcode::kNegate));

  // Computation that holds the inner conditional. Its parameter is a
  // (pred, f32, f32) tuple that is unpacked into the conditional's operands.
  const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {});
  const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {scalar_bool_shape, scalar_shape_, scalar_shape_});
  HloComputation::Builder inner_builder(TestName() + "_inner_conditional");
  HloInstruction* inner_param = inner_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond"));
  HloInstruction* inner_pred = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_bool_shape, inner_param,
                                            0));
  HloInstruction* inner_true_operand = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1));
  HloInstruction* inner_false_operand = inner_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 2));
  HloInstruction* inner_cond =
      inner_builder.AddInstruction(HloInstruction::CreateConditional(
          scalar_shape_, inner_pred, inner_true_operand, ceil_comp,
          inner_false_operand, floor_comp));
  HloComputation* inner_cond_comp =
      module_->AddEmbeddedComputation(inner_builder.Build());

  // Entry computation holding the outer conditional.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* outer_pred = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* packed_pred = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.2f)));
  HloInstruction* c3 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.3f)));
  HloInstruction* packed_operand = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({packed_pred, c1, c2}));
  HloInstruction* outer_cond =
      entry_builder.AddInstruction(HloInstruction::CreateConditional(
          scalar_shape_, outer_pred, packed_operand, inner_cond_comp, c3,
          negate_comp));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Values are defined by the entry constants, the tuple, and the root of
  // each leaf computation.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(outer_pred));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(packed_pred));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c2));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c3));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(packed_operand));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(ceil_comp->root_instruction()));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(floor_comp->root_instruction()));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(negate_comp->root_instruction()));

  // Leaf parameters define nothing; each carries the constant routed to it.
  HloInstruction* ceil_param = ceil_comp->parameter_instruction(0);
  HloInstruction* floor_param = floor_comp->parameter_instruction(0);
  HloInstruction* negate_param = negate_comp->parameter_instruction(0);
  EXPECT_FALSE(analysis.ValueIsDefinedAt(ceil_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(floor_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(negate_param));
  EXPECT_EQ(analysis.GetUniqueValueAt(ceil_param),
            analysis.GetValueDefinedAt(c1));
  EXPECT_EQ(analysis.GetUniqueValueAt(floor_param),
            analysis.GetValueDefinedAt(c2));
  EXPECT_EQ(analysis.GetUniqueValueAt(negate_param),
            analysis.GetValueDefinedAt(c3));

  // The inner computation's parameter and its get-tuple-elements likewise
  // only forward the entry tuple and its elements.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_param));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_pred));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_true_operand));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_false_operand));
  EXPECT_EQ(analysis.GetUniqueValueAt(inner_param),
            analysis.GetValueDefinedAt(packed_operand));
  EXPECT_EQ(analysis.GetUniqueValueAt(inner_pred),
            analysis.GetValueDefinedAt(packed_pred));
  EXPECT_EQ(analysis.GetUniqueValueAt(inner_true_operand),
            analysis.GetValueDefinedAt(c1));
  EXPECT_EQ(analysis.GetUniqueValueAt(inner_false_operand),
            analysis.GetValueDefinedAt(c2));

  if (!ssa_form) {
    // Non-SSA: neither conditional defines a value. The inner one may hold
    // the ceil or floor result; the outer one may additionally hold the
    // negate result.
    EXPECT_EQ(analysis.values().size(), 9);
    EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_cond));
    EXPECT_FALSE(analysis.ValueIsDefinedAt(outer_cond));
    EXPECT_THAT(
        HloValuesAt(inner_cond),
        UnorderedElementsAre(
            &analysis.GetValueDefinedAt(ceil_comp->root_instruction()),
            &analysis.GetValueDefinedAt(floor_comp->root_instruction())));
    EXPECT_THAT(
        HloValuesAt(outer_cond),
        UnorderedElementsAre(
            &analysis.GetValueDefinedAt(ceil_comp->root_instruction()),
            &analysis.GetValueDefinedAt(floor_comp->root_instruction()),
            &analysis.GetValueDefinedAt(negate_comp->root_instruction())));
  } else {
    // SSA: each conditional defines one additional value.
    EXPECT_EQ(analysis.values().size(), 11);
    EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_cond));
    EXPECT_TRUE(analysis.ValueIsDefinedAt(outer_cond));
  }
}
1963
TEST_P(HloDataflowAnalysisTest, AddDependency) {
  // add-dependency forwards its first operand, so it should not define a
  // value of its own.
  const std::string hlo_text = R"(
HloModule AddDependency
ENTRY %AddDependency (p: f32[3]) -> f32[3] {
  %p = f32[3] parameter(0)
  %token0 = token[] after-all()
  ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> parsed_module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  // The analysis is run directly with its default options rather than through
  // the fixture helper, so the test parameter is not consulted here.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> analysis,
                          HloDataflowAnalysis::Run(*parsed_module));
  const HloInstruction* add_dep =
      parsed_module->entry_computation()->root_instruction();
  EXPECT_EQ(add_dep->opcode(), HloOpcode::kAddDependency);

  // Exactly two values are expected: one for the parameter and one for the
  // after-all token. The add-dependency root contributes none.
  EXPECT_EQ(analysis->values().size(), 2);
  EXPECT_FALSE(analysis->ValueIsDefinedAt(add_dep));
}
1987
TEST_F(HloDataflowAnalysisTest, AllReduceStartAndDone) {
  // Single-operand async all-reduce: the start op is expected to define the
  // value and the done op to forward it.
  const char* hlo_text = R"(
    HloModule test
    add {
      x = f32[] parameter(0)
      y = f32[] parameter(1)
      ROOT add = f32[] add(x, y)
    }
    ENTRY entry {
      p0 = f32[2] parameter(0)
      start = f32[2] all-reduce-start(p0), to_apply=add
      ROOT done = f32[2] all-reduce-done(start)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> analysis,
                          HloDataflowAnalysis::Run(*parsed_module));

  // Walk back from the root: done <- start <- p0.
  HloInstruction* done_op =
      parsed_module->entry_computation()->root_instruction();
  HloInstruction* start_op = done_op->mutable_operand(0);
  HloInstruction* input = start_op->mutable_operand(0);

  EXPECT_TRUE(analysis->ValueIsDefinedAt(start_op, /*index=*/{}));
  EXPECT_FALSE(analysis->ValueIsDefinedAt(done_op));

  // The parameter is consumed only by start, and start's value only by done.
  EXPECT_THAT(analysis->GetValueDefinedAt(input).GetUses(),
              UnorderedElementsAre(HloUse{start_op, 0, {}}));
  EXPECT_THAT(analysis->GetValueDefinedAt(start_op).GetUses(),
              UnorderedElementsAre(HloUse{done_op, 0, {}}));
}
2020
TEST_F(HloDataflowAnalysisTest, AllReduceStartAndDoneTwoOperands) {
  // Two-operand async all-reduce: the start op produces a tuple and is
  // expected to define a value at the tuple itself and at each element, while
  // the done op defines nothing at its top-level index.
  const char* hlo_text = R"(
    HloModule test
    add {
      x = f32[] parameter(0)
      y = f32[] parameter(1)
      ROOT add = f32[] add(x, y)
    }
    ENTRY entry {
      p0 = f32[2] parameter(0)
      p1 = f32[2] parameter(1)
      start = (f32[2], f32[2]) all-reduce-start(p0, p1), to_apply=add
      ROOT done = (f32[2], f32[2]) all-reduce-done(start)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> analysis,
                          HloDataflowAnalysis::Run(*parsed_module));

  // Walk back from the root: done <- start <- {p0, p1}.
  HloInstruction* done_op =
      parsed_module->entry_computation()->root_instruction();
  HloInstruction* start_op = done_op->mutable_operand(0);
  HloInstruction* first_input = start_op->mutable_operand(0);
  HloInstruction* second_input = start_op->mutable_operand(1);

  EXPECT_TRUE(analysis->ValueIsDefinedAt(start_op, /*index=*/{}));
  EXPECT_TRUE(analysis->ValueIsDefinedAt(start_op, /*index=*/{0}));
  EXPECT_TRUE(analysis->ValueIsDefinedAt(start_op, /*index=*/{1}));
  EXPECT_FALSE(analysis->ValueIsDefinedAt(done_op));

  // Each parameter is consumed only at its own operand slot of start, and
  // start's top-level value only by done.
  EXPECT_THAT(analysis->GetValueDefinedAt(first_input).GetUses(),
              UnorderedElementsAre(HloUse{start_op, 0, {}}));
  EXPECT_THAT(analysis->GetValueDefinedAt(second_input).GetUses(),
              UnorderedElementsAre(HloUse{start_op, 1, {}}));
  EXPECT_THAT(analysis->GetValueDefinedAt(start_op, {}).GetUses(),
              UnorderedElementsAre(HloUse{done_op, 0, {}}));
}
2059
// Instantiates the HloDataflowAnalysisTest suite once per boolean parameter
// value (false and true).
// NOTE(review): the fixture is defined outside this section; confirm there
// what the boolean parameter selects.
INSTANTIATE_TEST_SUITE_P(HloDataflowAnalysisInstantiation,
                         HloDataflowAnalysisTest,
                         ::testing::Values(false, true));
2063
RunAnalysis(const HloModule & module,const HloDataflowAnalysis::CanShareBuffer & can_share_buffer=nullptr)2064 std::unique_ptr<HloDataflowAnalysis> RunAnalysis(
2065 const HloModule& module,
2066 const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) {
2067 return HloDataflowAnalysis::Run(module, /*ssa_form=*/false,
2068 /*bitcast_defines_value=*/false,
2069 can_share_buffer)
2070 .value();
2071 }
2072
// Fixture name for the DoesNotUseOperandBuffer tests that follow; it is a
// plain alias of HloTestBase and adds no state of its own.
using DoesNotUseOperandBufferTest = HloTestBase;
2074
// Builds add(gte(tuple, 0), gte(tuple, 1)) over a two-element tuple parameter
// and checks which buffers of the tuple each get-tuple-element uses.
TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
  HloComputation::Builder builder(TestName());

  const Shape elem_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* tuple_param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeTupleShape({elem_shape, elem_shape}), "tuple"));
  HloInstruction* element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, tuple_param, 0));
  HloInstruction* element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, tuple_param, 1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      elem_shape, HloOpcode::kAdd, element0, element1));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto analysis = RunAnalysis(*module);

  // A get-tuple-element reads only the top-level buffer of its operand, so the
  // element buffers ({0}, {1}) count as unused while the top-level buffer ({})
  // counts as used.
  EXPECT_TRUE(analysis->DoesNotUseOperandBuffer(tuple_param, {0}, element0));
  EXPECT_TRUE(analysis->DoesNotUseOperandBuffer(tuple_param, {1}, element1));
  EXPECT_FALSE(analysis->DoesNotUseOperandBuffer(tuple_param, {}, element0));
  EXPECT_FALSE(analysis->DoesNotUseOperandBuffer(tuple_param, {}, element1));
}
2099
// Fuses a dynamic-update-slice of tuple element 1 (together with its index,
// update and the get-tuple-element feeding it) into a loop fusion, then checks
// that the fusion is reported as using element 1 of the tuple but not
// element 0.
TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
  HloComputation::Builder builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* tuple_param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
  HloInstruction* element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));
  HloInstruction* element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1));

  // dynamic-update-slice writes a three-element update into element 1 at
  // start index 2.
  HloInstruction* start_index = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(2)));
  HloInstruction* update_value =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, element1, update_value,
          std::initializer_list<HloInstruction*>({start_index})));
  builder.AddInstruction(HloInstruction::CreateTuple({element0, dus}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {dus, start_index, update_value, element1},
      HloInstruction::FusionKind::kLoop);
  auto analysis = RunAnalysis(*module);

  // Element 0 never reaches the fusion; element 1 does.
  EXPECT_TRUE(analysis->DoesNotUseOperandBuffer(tuple_param, {0}, fusion));
  EXPECT_FALSE(analysis->DoesNotUseOperandBuffer(tuple_param, {1}, fusion));
}
2134
2135 // Similar to FusedDynamicUpdateSlice above, but tests indirect uses of the
2136 // parameter tuple.
TEST_F(DoesNotUseOperandBufferTest, IndirectUses) {
  HloComputation::Builder builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* tuple_param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
  HloInstruction* param_elem0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));
  HloInstruction* param_elem1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1));
  // Re-pack the parameter's elements in the opposite order.
  HloInstruction* swapped_tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({param_elem1, param_elem0}));

  HloInstruction* swapped_elem0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, swapped_tuple, 0));
  HloInstruction* swapped_elem1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, swapped_tuple, 1));

  // dynamic-update-slice writes a three-element update into element 1 of the
  // swapped tuple at start index 2.
  HloInstruction* start_index = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(2)));
  HloInstruction* update_value =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, swapped_elem1, update_value,
          std::initializer_list<HloInstruction*>({start_index})));
  builder.AddInstruction(HloInstruction::CreateTuple({swapped_elem0, dus}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {dus, start_index, update_value, swapped_elem1},
      HloInstruction::FusionKind::kLoop);
  auto analysis = RunAnalysis(*module);

  // Relative to the swapped tuple, the fusion uses element 1 and never
  // element 0.
  EXPECT_TRUE(analysis->DoesNotUseOperandBuffer(swapped_tuple, {0}, fusion));
  EXPECT_FALSE(analysis->DoesNotUseOperandBuffer(swapped_tuple, {1}, fusion));
  // Relative to the parameter tuple the indices are mirrored, because the
  // swapped tuple holds the parameter's elements in reverse order.
  EXPECT_TRUE(analysis->DoesNotUseOperandBuffer(tuple_param, {1}, fusion));
  EXPECT_FALSE(analysis->DoesNotUseOperandBuffer(tuple_param, {0}, fusion));
}
2184
// Fixture name for the CanShareOperandBufferWithUser tests that follow; it is
// a plain alias of HloTestBase and adds no state of its own.
using CanShareOperandBufferWithUserTest = HloTestBase;
2186
// Chain param -> exp -> log where every instruction has the same shape; each
// operand is expected to be shareable with its user.
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
  HloComputation::Builder builder(TestName());

  const Shape shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  HloInstruction* exp_op = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, input));
  HloInstruction* log_op = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kLog, exp_op));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto analysis = RunAnalysis(*module);

  EXPECT_TRUE(analysis->CanShareOperandBufferWithUser(input, {}, exp_op, {}));
  EXPECT_TRUE(analysis->CanShareOperandBufferWithUser(exp_op, {}, log_op, {}));
}
2207
// Loop fusion of negate followed by reverse: the fusion is expected NOT to be
// able to share the buffer of its parameter operand.
TEST_F(CanShareOperandBufferWithUserTest,
       NonElementwiseLoopFusionCantAliasOperandBuffer) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "param0"));

  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, input));

  HloInstruction* reverse = builder.AddInstruction(
      HloInstruction::CreateReverse(data_shape, negate, {0, 1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {reverse, negate}, HloInstruction::FusionKind::kLoop);
  auto analysis = RunAnalysis(*module);

  EXPECT_FALSE(analysis->CanShareOperandBufferWithUser(input, {}, fusion, {}));
}
2231
// Fuses two copies and the tuple that gathers them (in swapped order) into a
// single loop fusion with a tuple-shaped result, then expects that either
// parameter may share its buffer with either element of the fusion output.
TEST_F(CanShareOperandBufferWithUserTest,
       MultiOutputFusionCanAliasOperandBuffer) {
  auto builder = HloComputation::Builder(TestName());

  // Both parameters and both copies use this one shape. (Two additional shape
  // locals that were never referenced have been dropped.)
  Shape in_shape = ShapeUtil::MakeShape(F32, {8});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, in_shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, in_shape, "param1"));

  auto copy0 = builder.AddInstruction(
      HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0));
  auto copy1 = builder.AddInstruction(
      HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1));

  // Note the swapped order: tuple element 0 is the copy of param1.
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0}));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  auto fusion = computation->CreateFusionInstruction(
      {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop);
  auto dataflow_analysis = RunAnalysis(*module);

  // Every (parameter, output element) pairing is expected to be shareable.
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param0, {},
                                                               fusion, {0}));
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param0, {},
                                                               fusion, {1}));
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param1, {},
                                                               fusion, {0}));
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param1, {},
                                                               fusion, {1}));
}
2267
// Loop fusion of two element-wise ops (negate, exp) applied to a broadcast
// constant: the fusion is expected to be able to share the operand's buffer.
// NOTE(review): the test name says "Cant" while the expectation below is
// EXPECT_TRUE; the name is kept as-is so existing test filters keep working.
TEST_F(CanShareOperandBufferWithUserTest,
       ElementwiseLoopFusionCantAliasOperandBuffer) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  HloInstruction* scalar_one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, scalar_one, {}));

  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, broadcast));

  HloInstruction* exp_op = builder.AddInstruction(
      HloInstruction::CreateUnary(data_shape, HloOpcode::kExp, negate));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {exp_op, negate}, HloInstruction::FusionKind::kLoop);
  auto analysis = RunAnalysis(*module);

  EXPECT_TRUE(
      analysis->CanShareOperandBufferWithUser(broadcast, {}, fusion, {}));
}
2293
// A dynamic-slice of the parameter is written back into the same parameter by
// a dynamic-update-slice that uses the same (zero, zero) start indices. After
// fusing both ops and the index constant into a loop fusion, the parameter is
// expected to be shareable with the fusion output.
TEST_F(CanShareOperandBufferWithUserTest,
       CanShareOperandWhenDynamicUpdateSliceIsFedByDynamicSliceWithSameIndex) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {1, 2});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "param0"));
  HloInstruction* zero_index = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64_t>(0)));
  HloInstruction* dynamic_slice =
      builder.AddInstruction(HloInstruction::CreateDynamicSlice(
          slice_shape, input, {zero_index, zero_index}, {1, 2}));

  HloInstruction* dynamic_update_slice =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, input, dynamic_slice, {zero_index, zero_index}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {dynamic_update_slice, dynamic_slice, zero_index},
      HloInstruction::FusionKind::kLoop);
  auto analysis = RunAnalysis(*module);

  EXPECT_TRUE(analysis->CanShareOperandBufferWithUser(input, {}, fusion, {}));
}
2319
// HLO-text version of the dynamic-slice -> dynamic-update-slice pattern: the
// fused computation slices p0 at (p1, p2, p3) and writes the slice back at the
// same indices. Entry parameter 0 is expected to be shareable with the fusion.
TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithSameIndices) {
  const char* const kHloText = R"(
    HloModule test

    fused_computation {
      p0 = f32[10,20,30] parameter(0)
      p1 = s32[] parameter(1)
      p2 = s32[] parameter(2)
      p3 = s32[] parameter(3)
      slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30}
      ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p2, p3)
    }

    ENTRY test {
      p0 = f32[10,20,30] parameter(0)
      p1 = s32[] parameter(1)
      p2 = s32[] parameter(2)
      p3 = s32[] parameter(3)
      ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloComputation* entry = module->entry_computation();
  HloInstruction* fusion = entry->root_instruction();
  HloInstruction* data_param = entry->parameter_instruction(0);

  auto analysis = RunAnalysis(*module);
  EXPECT_TRUE(
      analysis->CanShareOperandBufferWithUser(data_param, {}, fusion, {}));
}
2349
// Compare of two F32 operands producing a PRED result: because the result
// shape differs from the operand shape, neither operand is expected to be
// shareable with the compare.
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
  HloComputation::Builder builder(TestName());

  const Shape in_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape out_shape = ShapeUtil::MakeShape(PRED, {8});
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateParameter(0, in_shape, "param0"));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateParameter(1, in_shape, "param1"));
  HloInstruction* compare =
      builder.AddInstruction(HloInstruction::CreateCompare(
          out_shape, lhs, rhs, ComparisonDirection::kEq));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto analysis = RunAnalysis(*module);

  EXPECT_FALSE(analysis->CanShareOperandBufferWithUser(lhs, {}, compare, {}));
  EXPECT_FALSE(analysis->CanShareOperandBufferWithUser(rhs, {}, compare, {}));
}
2371
// Chain param -> exp -> copy with one shape throughout; both the exp and the
// copy are expected to be able to share their operand's buffer.
TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
  HloComputation::Builder builder(TestName());

  const Shape shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param"));
  HloInstruction* exp_op = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kExp, input));
  HloInstruction* copy_op = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCopy, exp_op));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto analysis = RunAnalysis(*module);

  EXPECT_TRUE(analysis->CanShareOperandBufferWithUser(input, {}, exp_op, {}));
  EXPECT_TRUE(
      analysis->CanShareOperandBufferWithUser(exp_op, {}, copy_op, {}));
}
2392
// Fuses a dynamic-update-slice of tuple element 1 (plus its index, update and
// the get-tuple-element feeding it) into a loop fusion, then checks that the
// fusion can share the buffer of tuple element 1 but not of element 0.
TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
  HloComputation::Builder builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* tuple_param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
  HloInstruction* element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));
  HloInstruction* element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1));

  // dynamic-update-slice writes a three-element update into element 1 at
  // start index 2.
  HloInstruction* start_index = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(2)));
  HloInstruction* update_value =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, element1, update_value,
          std::initializer_list<HloInstruction*>({start_index})));
  builder.AddInstruction(HloInstruction::CreateTuple({element0, dus}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {dus, start_index, update_value, element1},
      HloInstruction::FusionKind::kLoop);
  auto analysis = RunAnalysis(*module);

  // Only tuple element 1 is expected to be shareable with the fusion output.
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(tuple_param, {0}, fusion, {}));
  EXPECT_TRUE(
      analysis->CanShareOperandBufferWithUser(tuple_param, {1}, fusion, {}));
}
2429
// Same dynamic-update-slice pattern as the FusedDynamicUpdateSlice test, but
// with the update sandwiched between a convert to BF16 and a convert back to
// F32, all fused into one loop fusion. The get-tuple-element feeding the
// fusion is expected to be shareable with the fusion output.
TEST_F(CanShareOperandBufferWithUserTest,
       FusedDynamicUpdateSliceWithConvertCanShare) {
  HloComputation::Builder builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape data_shape_bf16 = ShapeUtil::MakeShape(BF16, {8});
  HloInstruction* tuple_param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
  HloInstruction* element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));
  HloInstruction* element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1));

  HloInstruction* to_bf16 = builder.AddInstruction(
      HloInstruction::CreateConvert(data_shape_bf16, element1));

  // dynamic-update-slice writes a three-element update into the converted
  // element 1 at start index 2.
  HloInstruction* start_index = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(2)));
  HloInstruction* update_value =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape_bf16, to_bf16, update_value,
          std::initializer_list<HloInstruction*>({start_index})));

  HloInstruction* to_f32 =
      builder.AddInstruction(HloInstruction::CreateConvert(data_shape, dus));
  builder.AddInstruction(HloInstruction::CreateTuple({element0, to_f32}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {to_f32, dus, start_index, update_value, to_bf16},
      HloInstruction::FusionKind::kLoop);
  auto analysis = RunAnalysis(*module);

  EXPECT_TRUE(
      analysis->CanShareOperandBufferWithUser(element1, {}, fusion, {}));
}
2470
// Unfused dynamic-update-slice over three parameters (data, update, start):
// only the data operand is expected to be shareable with the result.
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
  HloComputation::Builder builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 8});
  const Shape update_shape = ShapeUtil::MakeShape(F32, {1, 4});
  const Shape starts_shape = ShapeUtil::MakeShape(S32, {2});
  HloInstruction* data_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  HloInstruction* update_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, update_shape, "update"));
  HloInstruction* start_param = builder.AddInstruction(
      HloInstruction::CreateParameter(2, starts_shape, "start"));

  HloInstruction* dus =
      builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          data_shape, data_param, update_param, {start_param}));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  auto analysis = RunAnalysis(*module);

  // Sharing is expected for the data operand only, not for the update or the
  // start indices.
  EXPECT_TRUE(analysis->CanShareOperandBufferWithUser(data_param, {}, dus, {}));
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(update_param, {}, dus, {}));
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(start_param, {}, dus, {}));
}
2500
// Single-output scatter parsed from HLO text: the scattered-into operand is
// expected to be shareable with the result, while the indices and updates are
// not.
TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
  const char* const kHloText = R"(
    HloModule TensorFlowScatterV1

    update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
      lhs = s32[] parameter(0)
      ROOT rhs = s32[] parameter(1)
    }

    ENTRY main {
      operand = s32[3,3] parameter(0)
      indices = s32[2] parameter(1)
      updates = s32[2,3] parameter(2)
      ROOT scatter = s32[3,3] scatter(operand, indices, updates),
          to_apply=update_s32,
          update_window_dims={1},
          inserted_window_dims={0},
          scatter_dims_to_operand_dims={0},
          index_vector_dim=1
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloComputation* entry = module->entry_computation();
  auto analysis = RunAnalysis(*module);

  HloInstruction* operand = entry->parameter_instruction(0);
  HloInstruction* indices = entry->parameter_instruction(1);
  HloInstruction* updates = entry->parameter_instruction(2);
  HloInstruction* scatter = entry->root_instruction();

  EXPECT_TRUE(
      analysis->CanShareOperandBufferWithUser(operand, {}, scatter, {}));
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(indices, {}, scatter, {}));
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(updates, {}, scatter, {}));
}
2538
// Two-output scatter parsed from HLO text: operand i is expected to be
// shareable only with output tuple element i; the indices and both updates
// are expected to be shareable with neither output element.
TEST_F(CanShareOperandBufferWithUserTest, MultioutputScatterCanShare) {
  const char* const kHloText = R"(
    HloModule MultioutputScatter

    update {
      lhs0 = s32[] parameter(0)
      lhs1 = f32[] parameter(1)
      rhs0 = s32[] parameter(2)
      rhs1 = f32[] parameter(3)
      ROOT tuple = tuple(rhs0, rhs1)
    }

    ENTRY main {
      operand0 = s32[3,3] parameter(0)
      operand1 = f32[3,3] parameter(1)
      indices = s32[2] parameter(2)
      updates0 = s32[2,3] parameter(3)
      updates1 = f32[2,3] parameter(4)
      ROOT scatter = (s32[3,3], f32[3,3])
          scatter(operand0, operand1, indices, updates0, updates1),
          to_apply=update,
          update_window_dims={1},
          inserted_window_dims={0},
          scatter_dims_to_operand_dims={0},
          index_vector_dim=1
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloComputation* entry = module->entry_computation();
  auto analysis = RunAnalysis(*module);

  HloInstruction* operand0 = entry->parameter_instruction(0);
  HloInstruction* operand1 = entry->parameter_instruction(1);
  HloInstruction* indices = entry->parameter_instruction(2);
  HloInstruction* updates0 = entry->parameter_instruction(3);
  HloInstruction* updates1 = entry->parameter_instruction(4);
  HloInstruction* scatter = entry->root_instruction();

  // Operands: shareable with the matching output element only.
  EXPECT_TRUE(
      analysis->CanShareOperandBufferWithUser(operand0, {}, scatter, {0}));
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(operand0, {}, scatter, {1}));
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(operand1, {}, scatter, {0}));
  EXPECT_TRUE(
      analysis->CanShareOperandBufferWithUser(operand1, {}, scatter, {1}));
  // Indices: never shareable.
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(indices, {}, scatter, {0}));
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(indices, {}, scatter, {1}));
  // Updates: never shareable.
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(updates0, {}, scatter, {0}));
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(updates0, {}, scatter, {1}));
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(updates1, {}, scatter, {0}));
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(updates1, {}, scatter, {1}));
}
2598
// triangular-solve(a, b) parsed from HLO text: the second operand (b, which
// has the same shape as the result) is expected to be shareable with the
// result, while the first operand (a) is not.
TEST_F(CanShareOperandBufferWithUserTest, TriangularSolveCanShare) {
  const char* const kHloText = R"(
    HloModule TensorFlowTriangularSolve

    ENTRY main {
      a = f32[4,4]{1,0} parameter(0)
      b = f32[3,4]{1,0} parameter(1)
      ROOT triangular-solve = f32[3,4]{1,0} triangular-solve(a, b), lower=true,
          transpose_a=NO_TRANSPOSE
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloComputation* entry = module->entry_computation();
  auto analysis = RunAnalysis(*module);

  HloInstruction* matrix_a = entry->parameter_instruction(0);
  HloInstruction* matrix_b = entry->parameter_instruction(1);
  HloInstruction* solve = entry->root_instruction();

  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(matrix_a, {}, solve, {}));
  EXPECT_TRUE(
      analysis->CanShareOperandBufferWithUser(matrix_b, {}, solve, {}));
}
2623
// Single-operand sort created through MakeSortHlo: the keys operand is
// expected to be shareable with the sort result.
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
  HloComputation::Builder builder(TestName());
  auto module = CreateNewVerifiedModule();

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* keys_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction * sort_op,
      MakeSortHlo(keys_shape, {keys_param}, -1, /*is_stable=*/false, &builder,
                  module.get()));

  module->AddEntryComputation(builder.Build());
  auto analysis = RunAnalysis(*module);

  EXPECT_TRUE(
      analysis->CanShareOperandBufferWithUser(keys_param, {}, sort_op, {}));
}
2641
// Key/value sort with a tuple-shaped result: each operand is expected to be
// shareable with the tuple element at the same position, and with no other.
TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
  HloComputation::Builder builder(TestName());
  auto module = CreateNewVerifiedModule();

  const Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape values_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* keys_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  HloInstruction* values_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, values_shape, "values"));
  const Shape sort_shape =
      ShapeUtil::MakeTupleShape({keys_shape, values_shape});
  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction * sort_op,
      MakeSortHlo(sort_shape, {keys_param, values_param}, 0,
                  /*is_stable=*/false, &builder, module.get()));

  module->AddEntryComputation(builder.Build());
  auto analysis = RunAnalysis(*module);

  // Keys pair with tuple element 0.
  EXPECT_TRUE(
      analysis->CanShareOperandBufferWithUser(keys_param, {}, sort_op, {0}));
  // Values pair with tuple element 1.
  EXPECT_TRUE(
      analysis->CanShareOperandBufferWithUser(values_param, {}, sort_op, {1}));
  // The crossed pairings must be rejected.
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(keys_param, {}, sort_op, {1}));
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(values_param, {}, sort_op, {0}));
}
2673
// Output fusion of dot followed by add: the non-dot addend (a broadcast
// constant) is expected to be shareable with the fusion result.
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  // Contract lhs dimension 1 with rhs dimension 0, using default precision
  // for both operands.
  DotDimensionNumbers dimension_numbers;
  dimension_numbers.add_lhs_contracting_dimensions(1);
  dimension_numbers.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision;
  precision.mutable_operand_precision()->Resize(2, PrecisionConfig::DEFAULT);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      data_shape, lhs, rhs, dimension_numbers, precision));

  HloInstruction* scalar_one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* addend = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, scalar_one, {}));

  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, dot, addend));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {sum, dot}, HloInstruction::FusionKind::kOutput);
  auto analysis = RunAnalysis(*module);

  // The fused dot+add should be able to reuse the addend's buffer.
  EXPECT_TRUE(analysis->CanShareOperandBufferWithUser(addend, {}, fusion, {}));
}
2710
// Output fusion of reverse followed by add (with a fused constant): the
// fusion is expected NOT to be able to share the buffer of the operand that
// feeds the reverse.
TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  HloInstruction* scalar_one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, scalar_one, {}));

  HloInstruction* reverse = builder.AddInstruction(
      HloInstruction::CreateReverse(data_shape, broadcast, {0, 1}));

  HloInstruction* twos = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  HloInstruction* sum = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kAdd, reverse, twos));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {sum, twos, reverse}, HloInstruction::FusionKind::kOutput);
  auto analysis = RunAnalysis(*module);

  // The fused broadcast -> reverse -> add chain must not alias the broadcast's
  // buffer.
  EXPECT_FALSE(
      analysis->CanShareOperandBufferWithUser(broadcast, {}, fusion, {}));
}
2739
// Runs the analysis with a caller-supplied can_share_buffer hook that answers
// with IsLoopFusion() of the instruction it is given. The fusion built here is
// created with FusionKind::kInput, and the test expects sharing to be refused.
TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
  HloComputation::Builder entry_builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});

  HloInstruction* scalar_one = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* broadcast_operand = entry_builder.AddInstruction(
      HloInstruction::CreateBroadcast(matrix_shape, scalar_one, {}));
  // The operand feeds both sides of the multiply.
  HloInstruction* squared =
      entry_builder.AddInstruction(HloInstruction::CreateBinary(
          matrix_shape, HloOpcode::kMultiply, broadcast_operand,
          broadcast_operand));
  HloInstruction* twos =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
  HloInstruction* sum = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(matrix_shape, HloOpcode::kAdd, squared,
                                   twos));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(entry_builder.Build());
  HloInstruction* input_fusion = entry->CreateFusionInstruction(
      {sum, twos, squared}, HloInstruction::FusionKind::kInput);

  // Custom hook: the answer depends only on the first instruction argument.
  auto share_only_for_loop_fusion = [](const HloInstruction* instr,
                                       const HloInstruction*,
                                       const ShapeIndex&) {
    return instr->IsLoopFusion();
  };
  auto analysis =
      RunAnalysis(*module, /*can_share_buffer=*/share_only_for_loop_fusion);

  EXPECT_FALSE(analysis->CanShareOperandBufferWithUser(broadcast_operand, {},
                                                       input_fusion, {}));
}
2769
// Builds a while loop over an f32[8] value (condition: AND-reduce of
// data == data; body: data + data) and expects the while instruction to be
// allowed to share a buffer with its init operand.
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
  auto module = CreateNewVerifiedModule();
  const Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pred_scalar_shape = ShapeUtil::MakeShape(PRED, {});

  // Scalar logical-AND reducer used by the loop condition.
  HloComputation::Builder and_builder(TestName() + ".And");
  HloInstruction* and_lhs = and_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar_shape, "p0"));
  HloInstruction* and_rhs = and_builder.AddInstruction(
      HloInstruction::CreateParameter(1, pred_scalar_shape, "p1"));
  and_builder.AddInstruction(HloInstruction::CreateBinary(
      pred_scalar_shape, HloOpcode::kAnd, and_lhs, and_rhs));
  HloComputation* and_computation =
      module->AddEmbeddedComputation(and_builder.Build());

  // Loop condition: reduce (data == data) over dimension 0 with AND.
  HloComputation* cond_computation = nullptr;
  {
    HloComputation::Builder cond_builder(TestName() + ".Cond");
    HloInstruction* cond_data = cond_builder.AddInstruction(
        HloInstruction::CreateParameter(0, data_shape, "data"));
    HloInstruction* elementwise_eq =
        cond_builder.AddInstruction(HloInstruction::CreateCompare(
            ShapeUtil::MakeShape(PRED, {8}), cond_data, cond_data,
            ComparisonDirection::kEq));
    HloInstruction* init_true = cond_builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
    cond_builder.AddInstruction(HloInstruction::CreateReduce(
        ShapeUtil::MakeShape(PRED, {}), elementwise_eq, init_true, {0},
        and_computation));
    cond_computation = module->AddEmbeddedComputation(cond_builder.Build());
  }

  // Loop body: data + data.
  HloComputation* body_computation = nullptr;
  {
    HloComputation::Builder body_builder(TestName() + ".Body");
    HloInstruction* body_data = body_builder.AddInstruction(
        HloInstruction::CreateParameter(0, data_shape, "data"));
    body_builder.AddInstruction(HloInstruction::CreateBinary(
        data_shape, HloOpcode::kAdd, body_data, body_data));
    body_computation = module->AddEmbeddedComputation(body_builder.Build());
  }

  // Entry computation: while(data).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* loop_init = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  HloInstruction* while_loop =
      entry_builder.AddInstruction(HloInstruction::CreateWhile(
          data_shape, cond_computation, body_computation, loop_init));
  module->AddEntryComputation(entry_builder.Build());

  auto analysis = RunAnalysis(*module);

  // The While instruction can share with the data operand.
  EXPECT_TRUE(
      analysis->CanShareOperandBufferWithUser(loop_init, {}, while_loop, {}));
}
2825
2826 // Tests that Call can alias operand buffer if the only use of the operand
2827 // in the called computation is an elementwise instruction.
TEST_F(CanShareOperandBufferWithUserTest,CallToComputationWithFusionRoot)2828 TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
2829 Shape shape = ShapeUtil::MakeShape(F32, {8});
2830 // Build sub-computation with fusion root.
2831 auto sub_builder = HloComputation::Builder(TestName() + "_sub");
2832 auto sub_param = sub_builder.AddInstruction(
2833 HloInstruction::CreateParameter(0, shape, "sub_param"));
2834 auto one = sub_builder.AddInstruction(
2835 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
2836 auto ones = sub_builder.AddInstruction(
2837 HloInstruction::CreateBroadcast(shape, one, {}));
2838 auto add = sub_builder.AddInstruction(
2839 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));
2840
2841 auto module = CreateNewVerifiedModule();
2842 auto sub_computation = module->AddEmbeddedComputation(sub_builder.Build());
2843 sub_computation->CreateFusionInstruction({add, ones},
2844 HloInstruction::FusionKind::kLoop);
2845
2846 // Build entry-computation with kCall which calls 'sub_computation'.
2847 auto builder = HloComputation::Builder(TestName());
2848
2849 auto param = builder.AddInstruction(
2850 HloInstruction::CreateParameter(0, shape, "param"));
2851 auto reverse =
2852 builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0}));
2853 auto call = builder.AddInstruction(
2854 HloInstruction::CreateCall(shape, {reverse}, sub_computation));
2855 module->AddEntryComputation(builder.Build());
2856
2857 auto dataflow_analysis = RunAnalysis(*module);
2858
2859 EXPECT_TRUE(
2860 dataflow_analysis->CanShareOperandBufferWithUser(reverse, {}, call, {}));
2861 }
2862
// Input fusion whose two elementwise results are reshaped, concatenated and
// then sliced back apart. Each parameter is expected to be shareable with the
// tuple output derived from it, and p0 is expected not to be shareable with
// the differently sized output {1}.
TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceWithElementwise) {
  const char* kHloText = R"(
    HloModule test

    fused_computation {
      p0 = f32[10,20] parameter(0)
      p1 = f32[10,20] parameter(1)
      p2 = f32[10,10] parameter(2)
      p3 = f32[10,10] parameter(3)
      add0 = f32[10, 20] add(p0, p1)
      sub0 = f32[10, 10] subtract(p2, p3)
      reshape0 = f32[200] reshape(add0)
      reshape1 = f32[100] reshape(sub0)
      concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
      slice0 = f32[200] slice(concat0), slice={[0:200]}
      slice1 = f32[100] slice(concat0), slice={[200:300]}
      ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
    }

    ENTRY test {
      p0 = f32[10,20] parameter(0)
      p1 = f32[10,20] parameter(1)
      p2 = f32[10,10] parameter(2)
      p3 = f32[10,10] parameter(3)
      ROOT fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloComputation* entry = module->entry_computation();
  HloInstruction* fusion = entry->root_instruction();
  HloInstruction* param0 = entry->parameter_instruction(0);
  HloInstruction* param1 = entry->parameter_instruction(1);
  HloInstruction* param2 = entry->parameter_instruction(2);
  HloInstruction* param3 = entry->parameter_instruction(3);

  auto analysis = RunAnalysis(*module);
  // Queries whether `operand` (top-level index) may share a buffer with the
  // fusion output at `output_index`.
  auto shares_with_output = [&](HloInstruction* operand,
                                const ShapeIndex& output_index) {
    return analysis->CanShareOperandBufferWithUser(operand, {}, fusion,
                                                   output_index);
  };

  EXPECT_TRUE(shares_with_output(param0, {0}));
  EXPECT_TRUE(shares_with_output(param1, {0}));
  EXPECT_TRUE(shares_with_output(param2, {1}));
  EXPECT_TRUE(shares_with_output(param3, {1}));
  // Tensors of different sizes cannot share buffer.
  EXPECT_FALSE(shares_with_output(param0, {1}));
}
2910
// Negative cases for the concat/slice pattern: p0 reaches the concatenate both
// directly and through add0, so it is expected not to share with any output.
// p1 is expected to share only with output {1}.
TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceNegativeTest) {
  const char* kHloText = R"(
    HloModule test

    fused_computation {
      // p0 has multiple transitive uses fed to concat. So, p0 cannot share
      // buffer with outputs because the aliased output could be written before
      // all the uses of p0 are finished.
      p0 = f32[100] parameter(0)
      p1 = f32[100] parameter(1)
      add0 = f32[100] add(p0, p1)
      concat0 = f32[200] concatenate(p0, add0), dimensions={0}
      slice0 = f32[100] slice(concat0), slice={[0:100]}
      slice1 = f32[100] slice(concat0), slice={[100:200]}
      ROOT tuple = (f32[100], f32[100]) tuple(slice0, slice1)
    }

    ENTRY test {
      p0 = f32[100] parameter(0)
      p1 = f32[100] parameter(1)
      ROOT fusion = (f32[100], f32[100]) fusion(p0, p1),
                        kind=kInput, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloComputation* entry = module->entry_computation();
  HloInstruction* fusion = entry->root_instruction();
  HloInstruction* param0 = entry->parameter_instruction(0);
  HloInstruction* param1 = entry->parameter_instruction(1);

  auto analysis = RunAnalysis(*module);
  // Queries whether `operand` (top-level index) may share a buffer with the
  // fusion output at `output_index`.
  auto shares_with_output = [&](HloInstruction* operand,
                                const ShapeIndex& output_index) {
    return analysis->CanShareOperandBufferWithUser(operand, {}, fusion,
                                                   output_index);
  };

  // p0 cannot share with either fusion{0} or fusion{1}.
  EXPECT_FALSE(shares_with_output(param0, {0}));
  EXPECT_FALSE(shares_with_output(param0, {1}));
  // p1 cannot share with fusion{0} because we're not sure about their
  // relationship.
  EXPECT_FALSE(shares_with_output(param1, {0}));
  // p1 can share with fusion{1} because they will be executed in an
  // elementwise manner.
  EXPECT_TRUE(shares_with_output(param1, {1}));
}
2955
// Same pattern with two concatenates in one fusion: p0 feeds both concats
// directly and is expected to share with no output, while p1 is expected to
// share only with outputs {1} and {3}.
TEST_F(CanShareOperandBufferWithUserTest, MultipleConcatenates) {
  const char* kHloText = R"(
    HloModule test

    fused_computation {
      p0 = f32[100] parameter(0)
      p1 = f32[100] parameter(1)
      add0 = f32[100] add(p0, p1)
      sub0 = f32[100] subtract(p1, p1)
      concat0 = f32[200] concatenate(p0, add0), dimensions={0}
      slice0 = f32[100] slice(concat0), slice={[0:100]}
      slice1 = f32[100] slice(concat0), slice={[100:200]}
      concat1 = f32[200] concatenate(p0, sub0), dimensions={0}
      slice2 = f32[100] slice(concat1), slice={[0:100]}
      slice3 = f32[100] slice(concat1), slice={[100:200]}
      ROOT tuple = (f32[100], f32[100], f32[100], f32[100])
                       tuple(slice0, slice1, slice2, slice3)
    }

    ENTRY test {
      p0 = f32[100] parameter(0)
      p1 = f32[100] parameter(1)
      ROOT fusion = (f32[100], f32[100], f32[100], f32[100])
          fusion(p0, p1), kind=kInput, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloComputation* entry = module->entry_computation();
  HloInstruction* fusion = entry->root_instruction();
  HloInstruction* param0 = entry->parameter_instruction(0);
  HloInstruction* param1 = entry->parameter_instruction(1);

  auto analysis = RunAnalysis(*module);
  // Queries whether `operand` (top-level index) may share a buffer with the
  // fusion output at `output_index`.
  auto shares_with_output = [&](HloInstruction* operand,
                                const ShapeIndex& output_index) {
    return analysis->CanShareOperandBufferWithUser(operand, {}, fusion,
                                                   output_index);
  };

  // p0 cannot share.
  EXPECT_FALSE(shares_with_output(param0, {0}));
  EXPECT_FALSE(shares_with_output(param0, {1}));
  EXPECT_FALSE(shares_with_output(param0, {2}));
  EXPECT_FALSE(shares_with_output(param0, {3}));
  // p1 can share with either fusion{1} or fusion{3}.
  EXPECT_TRUE(shares_with_output(param1, {1}));
  EXPECT_TRUE(shares_with_output(param1, {3}));
  EXPECT_FALSE(shares_with_output(param1, {0}));
  EXPECT_FALSE(shares_with_output(param1, {2}));
}
3007
// Fixture name for the GetInPlaceInputOutputPairs tests below; it is a plain
// alias of HloTestBase and adds no state of its own.
using GetInPlaceInputOutputPairsTest = HloTestBase;
3009
// A bare dynamic-update-slice is expected to report exactly one in-place
// pair: operand 0 (index {}) with output index {}.
TEST_F(GetInPlaceInputOutputPairsTest, DUS) {
  const char* kHloText = R"(
    HloModule test

    ENTRY test {
      p0 = f32[10] parameter(0)
      p1 = f32[5] parameter(1)
      p2 = s32[] parameter(2)
      ROOT dus = f32[10] dynamic-update-slice(p0, p1, p2)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloInstruction* root = module->entry_computation()->root_instruction();

  const std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected = {
      {HloOperandIndex{0, {}}, ShapeIndex{}}};
  const auto actual = HloDataflowAnalysis::GetInPlaceInputOutputPairs(root);
  EXPECT_EQ(actual, expected);
}
3029
// A loop fusion whose root is a dynamic-update-slice is expected to report
// the same single pair as the bare instruction: operand 0 with output {}.
TEST_F(GetInPlaceInputOutputPairsTest, DUSFusion) {
  const char* kHloText = R"(
    HloModule test

    fused_computation {
      p0 = f32[10] parameter(0)
      p1 = f32[5] parameter(1)
      p2 = s32[] parameter(2)
      ROOT dus = f32[10] dynamic-update-slice(p0, p1, p2)
    }

    ENTRY test {
      p0 = f32[10] parameter(0)
      p1 = f32[5] parameter(1)
      p2 = s32[] parameter(2)
      ROOT fusion = f32[10] fusion(p0, p1, p2), kind=kLoop, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloInstruction* root = module->entry_computation()->root_instruction();

  const std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected = {
      {HloOperandIndex{0, {}}, ShapeIndex{}}};
  const auto actual = HloDataflowAnalysis::GetInPlaceInputOutputPairs(root);
  EXPECT_EQ(actual, expected);
}
3056
// A loop fusion that only contains an add is expected to report no in-place
// input/output pairs.
TEST_F(GetInPlaceInputOutputPairsTest, NonDUSFusion) {
  const char* kHloText = R"(
    HloModule test

    fused_computation {
      p0 = f32[10] parameter(0)
      p1 = f32[10] parameter(1)
      ROOT add = f32[10] add(p0, p1)
    }

    ENTRY test {
      p0 = f32[10] parameter(0)
      p1 = f32[10] parameter(1)
      ROOT fusion = f32[10] fusion(p0, p1), kind=kLoop, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloInstruction* root = module->entry_computation()->root_instruction();

  const auto actual = HloDataflowAnalysis::GetInPlaceInputOutputPairs(root);
  EXPECT_THAT(actual, IsEmpty());
}
3079
// The dynamic-update-slice sits two fusion levels deep; the outermost fusion
// is still expected to report operand 0 paired with output {}.
TEST_F(GetInPlaceInputOutputPairsTest, NestedDUSFusion) {
  const char* kHloText = R"(
    HloModule test

    fused_computation1 {
      p0 = f32[10] parameter(0)
      p1 = f32[5] parameter(1)
      p2 = s32[] parameter(2)
      ROOT dus = f32[10] dynamic-update-slice(p0, p1, p2)
    }

    fused_computation2 {
      p0 = f32[10] parameter(0)
      p1 = f32[5] parameter(1)
      p2 = s32[] parameter(2)
      ROOT fusion = f32[10] fusion(p0, p1, p2), kind=kLoop, calls=fused_computation1
    }

    ENTRY test {
      p0 = f32[10] parameter(0)
      p1 = f32[5] parameter(1)
      p2 = s32[] parameter(2)
      ROOT fusion = f32[10] fusion(p0, p1, p2), kind=kLoop, calls=fused_computation2
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloInstruction* outer_fusion =
      module->entry_computation()->root_instruction();

  const std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected = {
      {HloOperandIndex{0, {}}, ShapeIndex{}}};
  const auto actual =
      HloDataflowAnalysis::GetInPlaceInputOutputPairs(outer_fusion);
  EXPECT_EQ(actual, expected);
}
3113
// Nested multi-output fusions with tuple-shaped operands. The inner fusion is
// expected to pair operand 1 at index {1} with output {1}; the outer fusion
// is expected to pair operand 1 at index {0} with output {2}.
TEST_F(GetInPlaceInputOutputPairsTest, NestedMultiOutputDUSFusion) {
  const char* kHloText = R"(
    HloModule test

    fused_computation1 {
      p0 = s32[] parameter(0)
      p1 = (f32[5],f32[10]) parameter(1)
      gte0 = f32[5] get-tuple-element(p1), index=0
      gte1 = f32[10] get-tuple-element(p1), index=1
      dus = f32[10] dynamic-update-slice(gte1, gte0, p0)
      negate = f32[5] negate(gte0)
      ROOT tuple = (f32[5],f32[10]) tuple(negate, dus)
    }

    fused_computation2 {
      p0 = f32[5] parameter(0)
      p1 = (f32[10],s32[]) parameter(1)
      gte0 = f32[10] get-tuple-element(p1), index=0
      gte1 = s32[] get-tuple-element(p1), index=1
      in_tuple = (f32[5],f32[10]) tuple(p0, gte0)
      inner_fusion = (f32[5],f32[10]) fusion(gte1, in_tuple), kind=kLoop, calls=fused_computation1
      fusion_gte0 = f32[5] get-tuple-element(inner_fusion), index=0
      fusion_gte1 = f32[10] get-tuple-element(inner_fusion), index=1
      negate = f32[5] negate(p0)
      ROOT tuple = (f32[5],f32[5],f32[10]) tuple(negate, fusion_gte0, fusion_gte1)
    }

    ENTRY test {
      p0 = f32[5] parameter(0)
      p1 = (f32[10],s32[]) parameter(1)
      ROOT fusion = (f32[5],f32[5],f32[10]) fusion(p0, p1), kind=kLoop, calls=fused_computation2
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloInstruction* outer_fusion =
      module->entry_computation()->root_instruction();
  HloInstruction* inner_fusion = FindInstruction(module.get(), "inner_fusion");

  // Inner fusion first, then the enclosing one.
  const std::vector<std::pair<HloOperandIndex, ShapeIndex>> inner_expected = {
      {HloOperandIndex{1, {1}}, ShapeIndex{1}}};
  const auto inner_actual =
      HloDataflowAnalysis::GetInPlaceInputOutputPairs(inner_fusion);
  EXPECT_EQ(inner_actual, inner_expected);

  const std::vector<std::pair<HloOperandIndex, ShapeIndex>> outer_expected = {
      {HloOperandIndex{1, {0}}, ShapeIndex{2}}};
  const auto outer_actual =
      HloDataflowAnalysis::GetInPlaceInputOutputPairs(outer_fusion);
  EXPECT_EQ(outer_actual, outer_expected);
}
3162
3163 } // namespace
3164 } // namespace xla
3165