xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17 
18 #include <string>
19 
20 #include "tensorflow/compiler/xla/literal.h"
21 #include "tensorflow/compiler/xla/service/async_op_canonicalizer.h"
22 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_dce.h"
26 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
27 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
30 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/test_helpers.h"
35 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
36 #include "tensorflow/compiler/xla/xla_data.pb.h"
37 #include "tensorflow/core/lib/core/status_test_util.h"
38 #include "tensorflow/core/platform/logging.h"
39 #include "tensorflow/core/platform/test.h"
40 
41 namespace xla {
42 namespace {
43 
44 using ::testing::ElementsAre;
45 using ::testing::IsEmpty;
46 using ::testing::UnorderedElementsAre;
47 
48 // Test is parameterized on a bool which is whether the dataflow analysis is
49 // performed with SSA form.
50 class HloDataflowAnalysisTest : public HloTestBase,
51                                 public ::testing::WithParamInterface<bool> {
52  protected:
HloDataflowAnalysisTest()53   HloDataflowAnalysisTest() : module_(CreateNewVerifiedModule()) {}
54 
55   // Run dataflow analysis on the member module. For convenience returns a
56   // reference to the generated analysis stored in analysis_.
RunAnalysis(bool ssa_form,bool bitcast_defines_value=false,bool run_dce=true)57   const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
58                                          bool bitcast_defines_value = false,
59                                          bool run_dce = true) {
60     AsyncOpCanonicalizer async_op_canonicalizer;
61     EXPECT_TRUE(async_op_canonicalizer.Run(module_.get()).ok());
62     if (run_dce) {
63       HloDCE dce;
64       EXPECT_TRUE(dce.Run(module_.get()).ok());
65     }
66     FlattenCallGraph flatten;
67     EXPECT_TRUE(flatten.Run(module_.get()).ok());
68     analysis_ =
69         HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value)
70             .value();
71     return *analysis_;
72   }
73 
74   // Return a vector of the HloValues at the given program position.
HloValuesAt(const HloInstruction * instruction,const ShapeIndex & index={})75   const std::vector<const HloValue*>& HloValuesAt(
76       const HloInstruction* instruction, const ShapeIndex& index = {}) {
77     CHECK(analysis_ != nullptr);
78     return analysis_->GetValueSet(instruction, index).values();
79   }
80 
81   // Returns true if the top-level values for instructions 'a' and 'b' may
82   // interfere. Precondition: 'a' and 'b' define array-shaped values.
InstructionsMayInterfere(const HloOrdering & ordering,const HloInstruction * a,const HloInstruction * b)83   bool InstructionsMayInterfere(const HloOrdering& ordering,
84                                 const HloInstruction* a,
85                                 const HloInstruction* b) {
86     EXPECT_FALSE(a->shape().IsTuple());
87     EXPECT_FALSE(b->shape().IsTuple());
88     return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
89                                  analysis_->GetValueDefinedAt(b), *analysis_);
90   }
91 
CreateR0F32UnaryOpComputation(HloOpcode opcode)92   std::unique_ptr<HloComputation> CreateR0F32UnaryOpComputation(
93       HloOpcode opcode) {
94     HloComputation::Builder builder(TestName() + "." + HloOpcodeString(opcode));
95     HloInstruction* param0 = builder.AddInstruction(
96         HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
97     builder.AddInstruction(
98         HloInstruction::CreateUnary(scalar_shape_, opcode, param0));
99     return builder.Build();
100   }
101 
102   std::unique_ptr<HloModule> module_;
103   std::unique_ptr<HloDataflowAnalysis> analysis_;
104 
105   const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {});
106   const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42});
107   const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
108       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
109 };
110 
TEST_P(HloDataflowAnalysisTest, BinaryOperation) {
  // Smallest interesting graph: a single Add whose operands are two scalar
  // constants.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* lhs = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* rhs = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* sum = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, lhs, rhs));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  // Every instruction defines exactly one value of its own.
  EXPECT_EQ(analysis.values().size(), 3);
  EXPECT_TRUE(analysis.ValueIsDefinedAt(lhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(rhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(sum));

  const HloValue& lhs_value = analysis.GetValueDefinedAt(lhs);
  const HloValue& rhs_value = analysis.GetValueDefinedAt(rhs);
  const HloValue& sum_value = analysis.GetValueDefinedAt(sum);

  // Nothing in this graph forwards a value, so each value lives only at the
  // instruction that defines it.
  EXPECT_THAT(lhs_value.positions(),
              UnorderedElementsAre(HloPosition{lhs, {}}));
  EXPECT_THAT(rhs_value.positions(),
              UnorderedElementsAre(HloPosition{rhs, {}}));
  EXPECT_THAT(sum_value.positions(),
              UnorderedElementsAre(HloPosition{sum, {}}));

  // The constants feed operands 0 and 1 of the add; the add has no users.
  EXPECT_THAT(lhs_value.GetUses(), UnorderedElementsAre(HloUse{sum, 0, {}}));
  EXPECT_THAT(rhs_value.GetUses(), UnorderedElementsAre(HloUse{sum, 1, {}}));
  EXPECT_TRUE(sum_value.GetUses().empty());

  // Only the root (the add) is live out of the module.
  EXPECT_FALSE(lhs_value.live_out_of_module());
  EXPECT_FALSE(rhs_value.live_out_of_module());
  EXPECT_TRUE(sum_value.live_out_of_module());
}
153 
TEST_P(HloDataflowAnalysisTest,TupleAndGtes)154 TEST_P(HloDataflowAnalysisTest, TupleAndGtes) {
155   // Verify the dataflow through a Tuple and GetTupleElement instructions.
156   auto builder = HloComputation::Builder(TestName());
157   auto param0 = builder.AddInstruction(
158       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
159   auto param1 = builder.AddInstruction(
160       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
161   auto tuple =
162       builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
163   auto gte0 = builder.AddInstruction(
164       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 0));
165   auto gte1 = builder.AddInstruction(
166       HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1));
167   auto add = builder.AddInstruction(
168       HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1));
169   module_->AddEntryComputation(builder.Build());
170   SCOPED_TRACE(module_->ToString());
171 
172   bool ssa_form = GetParam();
173   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
174 
175   // The two params, tuple, and add should each define one value.
176   EXPECT_EQ(analysis.values().size(), 4);
177 
178   EXPECT_TRUE(analysis.ValueIsDefinedAt(param0));
179   EXPECT_TRUE(analysis.ValueIsDefinedAt(param1));
180   EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple, /*index=*/{}));
181   EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{0}));
182   EXPECT_FALSE(analysis.ValueIsDefinedAt(tuple, /*index=*/{1}));
183   EXPECT_FALSE(analysis.ValueIsDefinedAt(gte0));
184   EXPECT_FALSE(analysis.ValueIsDefinedAt(gte1));
185   EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
186 
187   // Verify the positions of the values.
188   EXPECT_THAT(
189       analysis.GetValueDefinedAt(param0).positions(),
190       UnorderedElementsAre(HloPosition{param0, {}}, HloPosition{tuple, {0}},
191                            HloPosition{gte0, {}}));
192   EXPECT_THAT(
193       analysis.GetValueDefinedAt(param1).positions(),
194       UnorderedElementsAre(HloPosition{param1, {}}, HloPosition{tuple, {1}},
195                            HloPosition{gte1, {}}));
196   EXPECT_THAT(analysis.GetValueDefinedAt(tuple).positions(),
197               UnorderedElementsAre(HloPosition{tuple, {}}));
198 
199   // Verify uses. Of interest is that a GetTupleElement instruction is only a
200   // use of the top-level value in the tuple operand.
201   EXPECT_THAT(analysis.GetValueDefinedAt(param0).GetUses(),
202               UnorderedElementsAre(HloUse{add, 0, {}}));
203   EXPECT_THAT(analysis.GetValueDefinedAt(param1).GetUses(),
204               UnorderedElementsAre(HloUse{add, 1, {}}));
205   EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).GetUses(),
206               UnorderedElementsAre(HloUse{gte0, 0, {}}, HloUse{gte1, 0, {}}));
207   EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
208 }
209 
TEST_P(HloDataflowAnalysisTest, NestedTuple) {
  // A two-element tuple is placed twice (plus a constant) inside an outer
  // tuple; element {1, 0} of the outer tuple is then extracted with two GTEs.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* inner =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({c1, c2}));
  HloInstruction* outer = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({inner, inner, c1}));
  HloInstruction* extracted_tuple = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(inner->shape(), outer, 1));
  HloInstruction* extracted_scalar = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, extracted_tuple, 0));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  EXPECT_EQ(analysis.values().size(), 4);

  // The first constant shows up everywhere it is forwarded to.
  EXPECT_THAT(
      analysis.GetValueDefinedAt(c1).positions(),
      UnorderedElementsAre(
          HloPosition{c1, {}}, HloPosition{inner, {0}},
          HloPosition{outer, {0, 0}}, HloPosition{outer, {1, 0}},
          HloPosition{outer, {2}}, HloPosition{extracted_tuple, {0}},
          HloPosition{extracted_scalar, {}}));
  // The first constant has a single use, at the root of the computation; the
  // second constant has none.
  EXPECT_THAT(analysis.GetValueDefinedAt(c1, /*index=*/{}).GetUses(),
              UnorderedElementsAre(HloUse{extracted_scalar, 0, {0}}));
  EXPECT_TRUE(analysis.GetValueDefinedAt(c2).GetUses().empty());

  // The top-level tuple values are consumed by the GTE instructions.
  EXPECT_THAT(analysis.GetValueDefinedAt(inner, /*index=*/{}).GetUses(),
              UnorderedElementsAre(HloUse{extracted_scalar, 0, {}}));
  EXPECT_THAT(analysis.GetValueDefinedAt(outer, /*index=*/{}).GetUses(),
              UnorderedElementsAre(HloUse{extracted_tuple, 0, {}}));

  // Only the first constant reaches the module output.
  EXPECT_TRUE(analysis.GetValueDefinedAt(c1).live_out_of_module());
  EXPECT_FALSE(analysis.GetValueDefinedAt(c2).live_out_of_module());
  EXPECT_FALSE(
      analysis.GetValueDefinedAt(inner, /*index=*/{}).live_out_of_module());
  EXPECT_FALSE(
      analysis.GetValueDefinedAt(outer, /*index=*/{}).live_out_of_module());
}
260 
TEST_P(HloDataflowAnalysisTest, SingleCall) {
  // One call of a subcomputation that adds its two array-shaped parameters.
  HloComputation::Builder callee_builder("Subcomputation");
  HloInstruction* callee_p0 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* callee_p1 = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* callee_add =
      callee_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape_, HloOpcode::kAdd, callee_p0, callee_p1));
  HloComputation* callee =
      module_->AddEmbeddedComputation(callee_builder.Build());

  HloComputation::Builder entry_builder(TestName());
  HloInstruction* c1 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* c2 = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* call_site = entry_builder.AddInstruction(
      HloInstruction::CreateCall(scalar_shape_, {c1, c2}, callee));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const HloDataflowAnalysis& analysis = RunAnalysis(/*ssa_form=*/GetParam());

  EXPECT_EQ(analysis.values().size(), 3);

  // Neither the callee's parameters nor the call instruction define values of
  // their own; their values flow in from elsewhere.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(c2));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_p0));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_p1));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(callee_add));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(call_site));

  const HloValue& c1_value = analysis.GetValueDefinedAt(c1);
  const HloValue& c2_value = analysis.GetValueDefinedAt(c2);
  const HloValue& add_value = analysis.GetValueDefinedAt(callee_add);

  // Parameters see the caller's constants; the call sees the callee's add.
  EXPECT_EQ(analysis.GetUniqueValueAt(callee_p0), c1_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(callee_p1), c2_value);
  EXPECT_EQ(analysis.GetUniqueValueAt(call_site), add_value);

  // Each constant is used both at the call site and by the add in the callee.
  EXPECT_THAT(c1_value.GetUses(),
              UnorderedElementsAre(HloUse{call_site, 0, {}},
                                   HloUse{callee_add, 0, {}}));
  EXPECT_THAT(c2_value.GetUses(),
              UnorderedElementsAre(HloUse{call_site, 1, {}},
                                   HloUse{callee_add, 1, {}}));

  EXPECT_TRUE(add_value.live_out_of_module());
}
311 
TEST_P(HloDataflowAnalysisTest,NestedCalls)312 TEST_P(HloDataflowAnalysisTest, NestedCalls) {
313   // Test a module with nested computations. HLO is:
314   //
315   // F32[] inner_computation(F32[] %param0, F32[] %param1):
316   //   %add = Add(%param0, %param1)
317   //
318   // F32[] outer_computation((F32[] %param0, F32[] %param1):
319   //  ;; Note that parameters are interchanged in the call.
320   //   %nested_call = Call(inner_computation, {%param1, %param0})
321   //
322   // F32[] entry:
323   //   %constant1 = Constant(1.0)
324   //   %constant2 = Constant(2.0)
325   //   %call = Call(outer_computation, {%constant1, %constant2})
326   //
327   auto inner_builder = HloComputation::Builder("InnerComputation");
328   auto inner_param0 = inner_builder.AddInstruction(
329       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
330   auto inner_param1 = inner_builder.AddInstruction(
331       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
332   auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
333       scalar_shape_, HloOpcode::kAdd, inner_param0, inner_param1));
334   HloComputation* inner_computation =
335       module_->AddEmbeddedComputation(inner_builder.Build());
336 
337   auto outer_builder = HloComputation::Builder("OuterComputation");
338   auto outer_param0 = outer_builder.AddInstruction(
339       HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
340   auto outer_param1 = outer_builder.AddInstruction(
341       HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
342   // Swizzle parameters.
343   auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall(
344       scalar_shape_, {outer_param1, outer_param0}, inner_computation));
345   HloComputation* outer_computation =
346       module_->AddEmbeddedComputation(outer_builder.Build());
347 
348   auto builder = HloComputation::Builder(TestName());
349   auto constant1 = builder.AddInstruction(
350       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
351   auto constant2 = builder.AddInstruction(
352       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
353   auto call = builder.AddInstruction(HloInstruction::CreateCall(
354       scalar_shape_, {constant1, constant2}, outer_computation));
355   module_->AddEntryComputation(builder.Build());
356   SCOPED_TRACE(module_->ToString());
357 
358   bool ssa_form = GetParam();
359   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
360 
361   // Only three values should be defined. Most instructions just pass through
362   // their operand values.
363   EXPECT_EQ(analysis.values().size(), 3);
364 
365   // Verify that the uses of the constants are properly swizzled by parameter
366   // permutation in nested_call.
367   EXPECT_THAT(
368       analysis.GetValueDefinedAt(constant1).GetUses(),
369       UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}},
370                            HloUse{add, 1, {}}));
371   EXPECT_THAT(
372       analysis.GetValueDefinedAt(constant2).GetUses(),
373       UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}},
374                            HloUse{add, 0, {}}));
375 
376   EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
377 }
378 
TEST_P(HloDataflowAnalysisTest,SingleWhile)379 TEST_P(HloDataflowAnalysisTest, SingleWhile) {
380   // Test a simple single while instruction. The while body includes a
381   // pass-through value. HLO:
382   //
383   // body((F32[], F32[]) %tuple_param):
384   //   %add = Add(%tuple_param{0}, %tuple_param{1})
385   //   return Tuple(%tuple_param{0}, %add)
386   //
387   // condition((F32[], F32[]) %tuple_param):
388   //   return Constant(false)
389   //
390   // entry:
391   //   %constant1 = Constant(1.0)
392   //   %constant2 = Constant(2.0)
393   //   %tuple = Tuple(%constant1, %constant2)
394   //   return While(%tuple, body, condition)
395   //
396   const Shape tuple_shape =
397       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
398 
399   // Element 0 passes transparently through the body.
400   auto body_builder = HloComputation::Builder("body");
401   auto body_param = body_builder.AddInstruction(
402       HloInstruction::CreateParameter(0, tuple_shape, "param"));
403   auto body_element_0 = body_builder.AddInstruction(
404       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
405   auto body_element_1 = body_builder.AddInstruction(
406       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
407   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
408       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
409   auto body_root = body_builder.AddInstruction(
410       HloInstruction::CreateTuple({body_element_0, add}));
411   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
412 
413   // Condition computation trivially returns a constant "false".
414   auto cond_builder = HloComputation::Builder("condition");
415   auto cond_param = cond_builder.AddInstruction(
416       HloInstruction::CreateParameter(0, tuple_shape, "param"));
417   auto cond_constant = cond_builder.AddInstruction(
418       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
419   HloComputation* condition =
420       module_->AddEmbeddedComputation(cond_builder.Build());
421 
422   auto builder = HloComputation::Builder(TestName());
423   auto constant1 = builder.AddInstruction(
424       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
425   auto constant2 = builder.AddInstruction(
426       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
427   auto tuple = builder.AddInstruction(
428       HloInstruction::CreateTuple({constant1, constant2}));
429   auto xla_while = builder.AddInstruction(
430       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
431   module_->AddEntryComputation(builder.Build());
432   SCOPED_TRACE(module_->ToString());
433 
434   bool ssa_form = GetParam();
435   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
436 
437   EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module());
438 
439   if (ssa_form) {
440     // Element 0 of the tuple passed through the body so no phi value is
441     // defined.
442     EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
443     EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
444     EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
445 
446     // Element 1 of the tuple should be a phi value.
447     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
448     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
449     EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
450     EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
451     EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
452     EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
453 
454     EXPECT_THAT(
455         analysis.GetValueDefinedAt(constant1).GetUses(),
456         UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{body_root, 0, {}},
457                              HloUse{xla_while, 0, {0}}));
458 
459     // Constant1 passes through the body and out of the module.
460     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
461     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
462                     .live_out_of_module());
463 
464     EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module());
465   } else {
466     // While instruction and subcomputation parameters should not define values
467     // in non-ssa form.
468     EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
469     EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
470     EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
471     EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
472     EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
473     EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
474 
475     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
476     EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module());
477   }
478 }
479 
TEST_P(HloDataflowAnalysisTest,SequentialWhiles)480 TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
481   // Test sequential while instructions. The while body includes a
482   // pass-through value. HLO:
483   //
484   // body((F32[], F32[]) %tuple_param):
485   //   %add = Add(%tuple_param{0}, %tuple_param{1})
486   //   return Tuple(%tuple_param{0}, %add)
487   //
488   // condition((F32[], F32[]) %tuple_param):
489   //   return Constant(false)
490   //
491   // entry:
492   //   %constant1 = Constant(1.0)
493   //   %constant2 = Constant(2.0)
494   //   %tuple = Tuple(%constant1, %constant2)
495   //   %while0 = While(%tuple, body, condition)
496   //   %while1 = While(%while0, body, condition)
497   //   return While(%while1, body, condition)
498   //
499   const Shape tuple_shape =
500       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
501 
502   // Element 0 passes transparently through the body.
503   auto body_builder = HloComputation::Builder("body");
504   auto body_param = body_builder.AddInstruction(
505       HloInstruction::CreateParameter(0, tuple_shape, "param"));
506   auto body_element_0 = body_builder.AddInstruction(
507       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
508   auto body_element_1 = body_builder.AddInstruction(
509       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
510   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
511       scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
512   body_builder.AddInstruction(
513       HloInstruction::CreateTuple({body_element_0, add}));
514   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
515 
516   auto cond_builder = HloComputation::Builder("condition");
517   cond_builder.AddInstruction(
518       HloInstruction::CreateParameter(0, tuple_shape, "param"));
519   cond_builder.AddInstruction(
520       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
521   HloComputation* condition =
522       module_->AddEmbeddedComputation(cond_builder.Build());
523 
524   auto builder = HloComputation::Builder(TestName());
525   auto constant1 = builder.AddInstruction(
526       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
527   auto constant2 = builder.AddInstruction(
528       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
529   auto tuple = builder.AddInstruction(
530       HloInstruction::CreateTuple({constant1, constant2}));
531   auto xla_while0 = builder.AddInstruction(
532       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
533   auto xla_while1 = builder.AddInstruction(
534       HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0));
535   auto xla_while2 = builder.AddInstruction(
536       HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1));
537   module_->AddEntryComputation(builder.Build());
538   SCOPED_TRACE(module_->ToString());
539 
540   bool ssa_form = GetParam();
541   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
542 
543   // Element 0 is passed through all the while instructions and out of the
544   // module..
545   EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
546             analysis.GetValueDefinedAt(constant1));
547   EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
548             analysis.GetValueDefinedAt(constant1));
549   EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
550             analysis.GetValueDefinedAt(constant1));
551   EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
552 }
553 
TEST_P(HloDataflowAnalysisTest,MultiLevelNestedWhile)554 TEST_P(HloDataflowAnalysisTest, MultiLevelNestedWhile) {
555   // Test nested while instructions. The level0 body (most inner while) and
556   // level1 body pass through the parameter, while level2 (most outer while)
557   // modifies it.
558   //
559   // level0_body((F32[]) %tuple_param):
560   //   return Tuple(%tuple_param{0})
561   //
562   // level1_body((F32[]) %tuple_param):
563   //   return While(%tuple_param{0}), body=level0
564   //
565   // level2_body((F32[]) %tuple_param):
566   //   while = While(%tuple_param{0}), body=level1
567   //.  return negate(%while{0})
568   //
569   // entry:
570   //   %constant = Constant(1.0)
571   //   %tuple = Tuple(%constant)
572   //   return While(%tuple), body=level2
573   //
574   const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_});
575   auto cond_builder = HloComputation::Builder("condition");
576   cond_builder.AddInstruction(
577       HloInstruction::CreateParameter(0, tuple_shape, "param"));
578   cond_builder.AddInstruction(
579       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
580   HloComputation* condition =
581       module_->AddEmbeddedComputation(cond_builder.Build());
582 
583   // level 0 passes transparently through the body.
584   auto level0_builder = HloComputation::Builder("level0_body");
585   auto level0_param = level0_builder.AddInstruction(
586       HloInstruction::CreateParameter(0, tuple_shape, "param"));
587   auto level0_element_0 = level0_builder.AddInstruction(
588       HloInstruction::CreateGetTupleElement(scalar_shape_, level0_param, 0));
589   auto level0_root = level0_builder.AddInstruction(
590       HloInstruction::CreateTuple({level0_element_0}));
591   HloComputation* level0_body =
592       module_->AddEmbeddedComputation(level0_builder.Build());
593 
594   // Element 1 passes transparently through the body.
595   auto level1_builder = HloComputation::Builder("level1_body");
596   auto level1_param = level1_builder.AddInstruction(
597       HloInstruction::CreateParameter(0, tuple_shape, "param"));
598   auto level1_root = level1_builder.AddInstruction(HloInstruction::CreateWhile(
599       tuple_shape, condition, level0_body, level1_param));
600   HloComputation* level1_body =
601       module_->AddEmbeddedComputation(level1_builder.Build());
602 
603   // Element 1 passes transparently through the body.
604   auto level2_builder = HloComputation::Builder("level2_body");
605   auto level2_param = level2_builder.AddInstruction(
606       HloInstruction::CreateParameter(0, tuple_shape, "param"));
607   auto level2_while = level2_builder.AddInstruction(HloInstruction::CreateWhile(
608       tuple_shape, condition, level1_body, level2_param));
609   auto level2_element_0 = level2_builder.AddInstruction(
610       HloInstruction::CreateGetTupleElement(scalar_shape_, level2_while, 0));
611   auto negate = level2_builder.AddInstruction(HloInstruction::CreateUnary(
612       scalar_shape_, HloOpcode::kNegate, level2_element_0));
613   level2_builder.AddInstruction(HloInstruction::CreateTuple({negate}));
614   HloComputation* level2_body =
615       module_->AddEmbeddedComputation(level2_builder.Build());
616 
617   auto builder = HloComputation::Builder(TestName());
618   auto constant1 = builder.AddInstruction(
619       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
620   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant1}));
621   builder.AddInstruction(
622       HloInstruction::CreateWhile(tuple_shape, condition, level2_body, tuple));
623   module_->AddEntryComputation(builder.Build());
624   SCOPED_TRACE(module_->ToString());
625 
626   bool ssa_form = GetParam();
627   if (!ssa_form) {
628     return;
629   }
630   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
631 
632   // Phi node on inner parameters and roots should have been eliminated.
633   EXPECT_FALSE(analysis.ValueIsDefinedAt(level1_param, /*index=*/{0}));
634   EXPECT_FALSE(analysis.ValueIsDefinedAt(level0_param, /*index=*/{0}));
635   EXPECT_FALSE(analysis.ValueIsDefinedAt(level1_root, /*index=*/{0}));
636   EXPECT_FALSE(analysis.ValueIsDefinedAt(level0_root, /*index=*/{0}));
637   EXPECT_TRUE(analysis.ValueIsDefinedAt(level2_param, /*index=*/{0}));
638   EXPECT_EQ(HloValuesAt(level1_param, /*index=*/{0}),
639             HloValuesAt(level2_param, /*index=*/{0}));
640   EXPECT_EQ(HloValuesAt(level0_param, /*index=*/{0}),
641             HloValuesAt(level2_param, /*index=*/{0}));
642   EXPECT_EQ(HloValuesAt(level1_root, /*index=*/{0}),
643             HloValuesAt(level2_param, /*index=*/{0}));
644   EXPECT_EQ(HloValuesAt(level0_root, /*index=*/{0}),
645             HloValuesAt(level2_param, /*index=*/{0}));
646 }
647 
TEST_P(HloDataflowAnalysisTest,NestedWhiles)648 TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
649   // Test nested while instructions. The inner body passes through element 0 of
650   // its parameter, and the outer body passes through element 1.  HLO:
651   //
652   // inner_body((F32[], F32[]) %tuple_param):
653   //   %add = Add(%tuple_param{0}, %tuple_param{1})
654   //   return Tuple(%tuple_param{0}, %add)
655   //
656   // outer_body((F32[], F32[]) %tuple_param):
657   //   %negate = Negate(%tuple_param{0})
658   //   %tuple = Tuple(%negate, %tuple_param{1})
659   //   return While(%tuple, inner_body, condition)
660   //
661   // entry:
662   //   %constant1 = Constant(1.0)
663   //   %constant2 = Constant(2.0)
664   //   %tuple = Tuple(%constant1, %constant2)
665   //   return While(%tuple, outer_body, condition)
666   //
667   const Shape tuple_shape =
668       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
669 
670   auto cond_builder = HloComputation::Builder("condition");
671   cond_builder.AddInstruction(
672       HloInstruction::CreateParameter(0, tuple_shape, "param"));
673   cond_builder.AddInstruction(
674       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
675   HloComputation* condition =
676       module_->AddEmbeddedComputation(cond_builder.Build());
677 
678   // Element 0 passes transparently through the body.
679   auto inner_builder = HloComputation::Builder("inner_body");
680   auto inner_param = inner_builder.AddInstruction(
681       HloInstruction::CreateParameter(0, tuple_shape, "param"));
682   auto inner_element_0 = inner_builder.AddInstruction(
683       HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 0));
684   auto inner_element_1 = inner_builder.AddInstruction(
685       HloInstruction::CreateGetTupleElement(scalar_shape_, inner_param, 1));
686   auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary(
687       scalar_shape_, HloOpcode::kAdd, inner_element_0, inner_element_1));
688   inner_builder.AddInstruction(
689       HloInstruction::CreateTuple({inner_element_0, add}));
690   HloComputation* inner_body =
691       module_->AddEmbeddedComputation(inner_builder.Build());
692 
693   // Element 1 passes transparently through the body.
694   auto outer_builder = HloComputation::Builder("outer_body");
695   auto outer_param = outer_builder.AddInstruction(
696       HloInstruction::CreateParameter(0, tuple_shape, "param"));
697   auto outer_element_0 = outer_builder.AddInstruction(
698       HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 0));
699   auto negate = outer_builder.AddInstruction(HloInstruction::CreateUnary(
700       scalar_shape_, HloOpcode::kNegate, outer_element_0));
701   auto outer_element_1 = outer_builder.AddInstruction(
702       HloInstruction::CreateGetTupleElement(scalar_shape_, outer_param, 1));
703   auto outer_tuple = outer_builder.AddInstruction(
704       HloInstruction::CreateTuple({negate, outer_element_1}));
705   auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile(
706       tuple_shape, condition, inner_body, outer_tuple));
707   HloComputation* outer_body =
708       module_->AddEmbeddedComputation(outer_builder.Build());
709 
710   auto builder = HloComputation::Builder(TestName());
711   auto constant1 = builder.AddInstruction(
712       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
713   auto constant2 = builder.AddInstruction(
714       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
715   auto tuple = builder.AddInstruction(
716       HloInstruction::CreateTuple({constant1, constant2}));
717   auto entry_while = builder.AddInstruction(
718       HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple));
719   module_->AddEntryComputation(builder.Build());
720   SCOPED_TRACE(module_->ToString());
721 
722   bool ssa_form = GetParam();
723   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
724 
725   EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
726               UnorderedElementsAre(&analysis.GetValueDefinedAt(negate)));
727   if (ssa_form) {
728     EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1}));
729     EXPECT_TRUE(
730         analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
731 
732     // Element 0 of the nested while is %negate.
733     EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
734     EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
735                 UnorderedElementsAre(&analysis.GetValueDefinedAt(negate)));
736     // Element 1 is a phi value (join of %add and %constant2).
737     EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1}));
738     EXPECT_TRUE(
739         analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi());
740 
741     EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{0}));
742     EXPECT_TRUE(
743         analysis.GetValueDefinedAt(entry_while, /*index=*/{0}).is_phi());
744 
745     EXPECT_TRUE(analysis.ValueIsDefinedAt(entry_while, /*index=*/{1}));
746     EXPECT_TRUE(
747         analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi());
748   } else {
749     EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}),
750                 UnorderedElementsAre(&analysis.GetValueDefinedAt(add),
751                                      &analysis.GetValueDefinedAt(constant2)));
752 
753     EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{0}),
754                 UnorderedElementsAre(&analysis.GetValueDefinedAt(negate)));
755     EXPECT_THAT(HloValuesAt(nested_while, /*index=*/{1}),
756                 UnorderedElementsAre(&analysis.GetValueDefinedAt(add),
757                                      &analysis.GetValueDefinedAt(constant2)));
758 
759     EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{0}),
760                 UnorderedElementsAre(&analysis.GetValueDefinedAt(negate),
761                                      &analysis.GetValueDefinedAt(constant1)));
762     EXPECT_THAT(HloValuesAt(entry_while, /*index=*/{1}),
763                 UnorderedElementsAre(&analysis.GetValueDefinedAt(add),
764                                      &analysis.GetValueDefinedAt(constant2)));
765   }
766 }
767 
TEST_P(HloDataflowAnalysisTest, SwizzlingWhileSharedInput) {
  // A while instruction whose body permutes its tuple parameter elements,
  // where both elements of the init tuple are the same constant. HLO:
  //
  // body((F32[], F32[]) %tuple_param):
  //   return Tuple(%tuple_param{1}, %tuple_param{0})
  //
  // condition((F32[], F32[]) %tuple_param):
  //   return Constant(false)
  //
  // entry:
  //   %constant1 = Constant(1.0)
  //   %tuple = Tuple(%constant1, %constant1)
  //   return While(%tuple, body, condition)
  //
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Body: swap the two tuple elements.
  HloComputation::Builder swap_builder("body");
  HloInstruction* swap_param = swap_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "param"));
  HloInstruction* first = swap_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, swap_param, 0));
  HloInstruction* second = swap_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, swap_param, 1));
  swap_builder.AddInstruction(HloInstruction::CreateTuple({second, first}));
  HloComputation* swap_body =
      module_->AddEmbeddedComputation(swap_builder.Build());

  // Condition: root is the constant `false`.
  HloComputation::Builder pred_builder("condition");
  pred_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "param"));
  pred_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* loop_condition =
      module_->AddEmbeddedComputation(pred_builder.Build());

  // Entry: both init elements are the same constant.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* shared_constant = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* init = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({shared_constant, shared_constant}));
  entry_builder.AddInstruction(HloInstruction::CreateWhile(
      pair_shape, loop_condition, swap_body, init));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  // In either analysis mode no value is expected to be defined at element 0 of
  // the body parameter.
  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
  EXPECT_FALSE(analysis.ValueIsDefinedAt(swap_param, /*index=*/{0}));
}
819 
TEST_P(HloDataflowAnalysisTest,SwizzlingWhile)820 TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) {
821   // Test a while instruction with a body which permutes it's tuple parameter
822   // elements. HLO:
823   //
824   // body((F32[], F32[]) %tuple_param):
825   //   return Tuple(%tuple_param{1}, %tuple_param{0})
826   //
827   // condition((F32[], F32[]) %tuple_param):
828   //   return Constant(false)
829   //
830   // entry:
831   //   %constant1 = Constant(1.0)
832   //   %constant2 = Constant(2.0)
833   //   %tuple = Tuple(%constant1, %constant2)
834   //   return While(%tuple, body, condition)
835   //
836   const Shape tuple_shape =
837       ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
838 
839   auto body_builder = HloComputation::Builder("body");
840   auto body_param = body_builder.AddInstruction(
841       HloInstruction::CreateParameter(0, tuple_shape, "param"));
842   auto body_element_0 = body_builder.AddInstruction(
843       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
844   auto body_element_1 = body_builder.AddInstruction(
845       HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
846   body_builder.AddInstruction(
847       HloInstruction::CreateTuple({body_element_1, body_element_0}));
848   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
849 
850   auto cond_builder = HloComputation::Builder("condition");
851   auto cond_param = cond_builder.AddInstruction(
852       HloInstruction::CreateParameter(0, tuple_shape, "param"));
853   cond_builder.AddInstruction(
854       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
855   HloComputation* condition =
856       module_->AddEmbeddedComputation(cond_builder.Build());
857 
858   auto builder = HloComputation::Builder(TestName());
859   auto constant1 = builder.AddInstruction(
860       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
861   auto constant2 = builder.AddInstruction(
862       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
863   auto tuple = builder.AddInstruction(
864       HloInstruction::CreateTuple({constant1, constant2}));
865   auto xla_while = builder.AddInstruction(
866       HloInstruction::CreateWhile(tuple_shape, condition, body, tuple));
867   module_->AddEntryComputation(builder.Build());
868   SCOPED_TRACE(module_->ToString());
869 
870   bool ssa_form = GetParam();
871   const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
872 
873   if (ssa_form) {
874     // Element 0 and 1 in the while should both be phi values.
875     EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
876     EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi());
877     EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
878     EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
879 
880     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
881     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
882     EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
883     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
884 
885     EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
886     EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi());
887     EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
888     EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
889 
890     EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
891     EXPECT_FALSE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
892     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{})
893                     .live_out_of_module());
894     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0})
895                     .live_out_of_module());
896     EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
897                     .live_out_of_module());
898   } else {
899     // Elements 0 and 1 have both constants as reaching definitions.
900     EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{0}),
901                 UnorderedElementsAre(&analysis.GetValueDefinedAt(constant1),
902                                      &analysis.GetValueDefinedAt(constant2)));
903     EXPECT_THAT(HloValuesAt(xla_while, /*index=*/{1}),
904                 UnorderedElementsAre(&analysis.GetValueDefinedAt(constant1),
905                                      &analysis.GetValueDefinedAt(constant2)));
906     EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
907     EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module());
908   }
909 }
910 
TEST_P(HloDataflowAnalysisTest, ArraySelect) {
  // A kSelect over array (scalar) operands: the select defines its own value,
  // and only that value is live out of the module.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* predicate = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloInstruction* on_true = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* on_false = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
  HloInstruction* selected =
      entry_builder.AddInstruction(HloInstruction::CreateTernary(
          scalar_shape_, HloOpcode::kSelect, predicate, on_true, on_false));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(selected));
  EXPECT_FALSE(analysis.GetValueDefinedAt(on_true).live_out_of_module());
  EXPECT_FALSE(analysis.GetValueDefinedAt(on_false).live_out_of_module());
  EXPECT_TRUE(analysis.GetValueDefinedAt(selected).live_out_of_module());
}
934 
TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) {
  // Exercises the bitcast_defines_value flag of the dataflow analysis on a
  // module consisting of a constant followed by a bitcast of it.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* source = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* cast = entry_builder.AddInstruction(
      HloInstruction::CreateBitcast(scalar_shape_, source));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();

  // Each run gets its own scope so the analysis reference from the first run
  // is not used after the second run.
  {
    // Flag on: the bitcast defines a second value, and that value is the one
    // live out of the module.
    const HloDataflowAnalysis& with_bitcast_value =
        RunAnalysis(ssa_form, /*bitcast_defines_value=*/true);

    EXPECT_EQ(with_bitcast_value.values().size(), 2);

    EXPECT_TRUE(with_bitcast_value.ValueIsDefinedAt(source));
    EXPECT_TRUE(with_bitcast_value.ValueIsDefinedAt(cast));
    EXPECT_FALSE(
        with_bitcast_value.GetValueDefinedAt(source).live_out_of_module());
    EXPECT_TRUE(
        with_bitcast_value.GetValueDefinedAt(cast).live_out_of_module());
  }
  {
    // Flag off: only the constant defines a value, and it is live out.
    const HloDataflowAnalysis& without_bitcast_value =
        RunAnalysis(ssa_form, /*bitcast_defines_value=*/false);
    EXPECT_EQ(without_bitcast_value.values().size(), 1);

    EXPECT_TRUE(without_bitcast_value.ValueIsDefinedAt(source));
    EXPECT_FALSE(without_bitcast_value.ValueIsDefinedAt(cast));
    EXPECT_TRUE(
        without_bitcast_value.GetValueDefinedAt(source).live_out_of_module());
  }
}
968 
TEST_P(HloDataflowAnalysisTest, TupleCopy) {
  // A tuple-shaped copy should define only the top-level (index {}) value;
  // its elements keep referring to the values defined by the parameters.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* lhs = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* rhs = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* pair =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({lhs, rhs}));
  HloInstruction* pair_copy = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(pair->shape(), HloOpcode::kCopy, pair));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Two parameters plus the top-level values of the tuple and of the copy.
  EXPECT_EQ(analysis.values().size(), 4);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(lhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(rhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pair, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pair_copy, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair_copy, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair_copy, /*index=*/{1}));

  // The copy's elements are the parameter values.
  EXPECT_THAT(HloValuesAt(pair_copy, /*index=*/{0}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(lhs)));
  EXPECT_THAT(HloValuesAt(pair_copy, /*index=*/{1}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(rhs)));
  EXPECT_TRUE(analysis.GetValueDefinedAt(pair_copy, /*index=*/{})
                  .live_out_of_module());
}
1004 
TEST_P(HloDataflowAnalysisTest, OptimizationBarrier) {
  // An optimization barrier should be a nop for dataflow: it defines no
  // values of its own and forwards the values of its operand.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* lhs = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* rhs = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "param1"));
  HloInstruction* pair =
      entry_builder.AddInstruction(HloInstruction::CreateTuple({lhs, rhs}));
  HloInstruction* opt_barrier =
      entry_builder.AddInstruction(HloInstruction::CreateUnary(
          pair->shape(), HloOpcode::kOptimizationBarrier, pair));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Two parameters plus the top-level tuple value; nothing from the barrier.
  EXPECT_EQ(analysis.values().size(), 3);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(lhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(rhs));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(pair, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(pair, /*index=*/{1}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(opt_barrier, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(opt_barrier, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(opt_barrier, /*index=*/{1}));

  // The barrier's elements are the parameter values.
  EXPECT_THAT(HloValuesAt(opt_barrier, /*index=*/{0}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(lhs)));
  EXPECT_THAT(HloValuesAt(opt_barrier, /*index=*/{1}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(rhs)));
}
1038 
TEST_P(HloDataflowAnalysisTest, CopyStartAndCopyDone) {
  // A CopyDone should forward element {0} of its CopyStart operand to its
  // output rather than defining a new value.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* source = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  // CopyStart result shape: (data, data, u32[]).
  const Shape start_shape = ShapeUtil::MakeTupleShape(
      {source->shape(), source->shape(), ShapeUtil::MakeShape(U32, {})});
  HloInstruction* start = entry_builder.AddInstruction(
      HloInstruction::CreateCopyStart(start_shape, source));
  HloInstruction* done =
      entry_builder.AddInstruction(HloInstruction::CreateUnary(
          source->shape(), HloOpcode::kCopyDone, start));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_EQ(analysis.values().size(), 4);

  // CopyStart defines its top-level value and elements {0} and {2}, but not
  // element {1}.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(start, /*index=*/{}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(start, /*index=*/{0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(start, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(start, /*index=*/{2}));

  // CopyDone defines nothing; its output is CopyStart's element {0}, which is
  // therefore live out of the module.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(done, /*index=*/{}));
  EXPECT_THAT(HloValuesAt(done, /*index=*/{}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(start, {0})));
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(start, /*index=*/{0}).live_out_of_module());
}
1070 
TEST_P(HloDataflowAnalysisTest, AsyncOps) {
  // Checks which positions of an async-start / async-update / async-done
  // chain wrapping a custom call define values, and which forward them.
  std::string hlo_str = R"(
  HloModule module

  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo"
    async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(async-start), custom_call_target="foo"
    ROOT async-done = f32[2,3] custom-call-done(async-update), custom_call_target="foo"
  }
)";
  TF_ASSERT_OK_AND_ASSIGN(
      module_, ParseAndReturnVerifiedModule(hlo_str, GetModuleConfigForTest()));

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  const HloInstruction* entry_param =
      module_->entry_computation()->parameter_instruction(0);
  const HloInstruction* start = FindInstruction(module_.get(), "async-start");
  const HloInstruction* update =
      FindInstruction(module_.get(), "async-update");
  const HloInstruction* done = FindInstruction(module_.get(), "async-done");
  const HloInstruction* wrapped = start->async_wrapped_instruction();

  // async-start: defines the top-level tuple and element {2}; element {0, 0}
  // is the entry parameter's value and element {1} is the wrapped
  // instruction's value, which is live out of the module.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(start, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(start, /*index=*/{0, 0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(start, /*index=*/{1}));
  EXPECT_THAT(HloValuesAt(start, {1}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(wrapped, {})));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(start, /*index=*/{2}));
  EXPECT_THAT(
      HloValuesAt(start, /*index=*/{0, 0}),
      UnorderedElementsAre(&analysis.GetValueDefinedAt(entry_param, {})));
  EXPECT_TRUE(analysis.GetValueDefinedAt(wrapped, {}).live_out_of_module());

  // async-update: defines only its top-level tuple; the elements are
  // forwarded.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(update, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(update, /*index=*/{0, 0}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(update, /*index=*/{1}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(update, /*index=*/{2}));
  EXPECT_THAT(
      HloValuesAt(update, /*index=*/{0, 0}),
      UnorderedElementsAre(&analysis.GetValueDefinedAt(entry_param, {})));
  EXPECT_THAT(HloValuesAt(update, /*index=*/{1}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(wrapped, {})));

  // async-done: defines nothing; its output is the wrapped instruction's
  // value.
  EXPECT_FALSE(analysis.ValueIsDefinedAt(done, /*index=*/{}));
  EXPECT_THAT(HloValuesAt(done, /*index=*/{}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(wrapped, {})));
}
1126 
TEST_P(HloDataflowAnalysisTest, AsyncCall) {
  // An asynchronous call (call-start / call-update / call-done) of a
  // two-parameter computation, interleaved with other entry instructions.
  std::string hlo_str = R"(
HloModule AsyncCall

%called_computation (param_0: f32[4096], param_1: f32[4096]) -> f32[4096] {
  %param_0 = f32[4096]{0} parameter(0)
  %param_1 = f32[4096]{0} parameter(1)
  %negate_0 = f32[4096]{0} negate(f32[4096]{0} %param_0)
  %negate_1 = f32[4096]{0} negate(f32[4096]{0} %param_1)
  ROOT %result.1 = f32[4096]{0} add(f32[4096]{0} %negate_0, f32[4096]{0} %negate_1)
}

ENTRY %main (a: f32[4096], b: f32[4096]) -> f32[4096] {
  %a = f32[4096]{0} parameter(0)
  %b = f32[4096]{0} parameter(1)
  %async-start = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) call-start(f32[4096]{0} %a, f32[4096]{0} %b), to_apply=%called_computation
  %negate_2 = f32[4096]{0} negate(f32[4096]{0} %a)
  %async-update = ((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) call-update(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-start), to_apply=%called_computation
  %negate_3 = f32[4096]{0} negate(f32[4096]{0} %b)
  %add_0 = f32[4096]{0} add(f32[4096]{0} %negate_2, f32[4096]{0} %negate_3)
  %async-done = f32[4096]{0} call-done(((f32[4096]{0}, f32[4096]{0}), f32[4096]{0}, u32[]) %async-update), to_apply=%called_computation
  ROOT %add_1 = f32[4096]{0} add(f32[4096]{0} %add_0, f32[4096]{0} %async-done)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      module_, ParseAndReturnVerifiedModule(hlo_str, GetModuleConfigForTest()));

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  const HloInstruction* entry_a = FindInstruction(module_.get(), "a");
  const HloInstruction* entry_b = FindInstruction(module_.get(), "b");
  const HloInstruction* done = FindInstruction(module_.get(), "async-done");

  // For every async op, the parameters of the computation it wraps must not
  // define values of their own but carry the values of the entry operands %a
  // and %b, and the computation's root must define the value observed at
  // async-done.
  for (std::string op_name : {"async-start", "async-update", "async-done"}) {
    const HloInstruction* op = FindInstruction(module_.get(), op_name);
    const HloComputation* callee =
        op->async_wrapped_instruction()->called_computations()[0];

    const HloInstruction* callee_param0 = callee->parameter_instruction(0);
    EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_param0));
    EXPECT_THAT(HloValuesAt(callee_param0),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(entry_a)));

    const HloInstruction* callee_param1 = callee->parameter_instruction(1);
    EXPECT_FALSE(analysis.ValueIsDefinedAt(callee_param1));
    EXPECT_THAT(HloValuesAt(callee_param1),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(entry_b)));

    const HloInstruction* callee_root = callee->root_instruction();
    EXPECT_TRUE(analysis.ValueIsDefinedAt(callee_root));
    EXPECT_THAT(HloValuesAt(done),
                UnorderedElementsAre(&analysis.GetValueDefinedAt(callee_root)));
  }
}
1184 
TEST_P(HloDataflowAnalysisTest, SendAndSendDone) {
  // A Send should forward its data operand into element {0} of its output
  // tuple instead of defining a new value there.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* payload = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape_, "param0"));
  HloInstruction* after_token =
      entry_builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* send_op = entry_builder.AddInstruction(
      HloInstruction::CreateSend(payload, after_token, /*channel_id=*/0));
  HloInstruction* send_done_op =
      entry_builder.AddInstruction(HloInstruction::CreateSendDone(send_op));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_EQ(analysis.values().size(), 6);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(payload));
  // Send defines its top-level tuple and elements {1} and {2}; element {0}
  // is the forwarded parameter value.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send_op, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(send_op, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send_op, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send_op, /*index=*/{2}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(send_done_op));
  EXPECT_THAT(HloValuesAt(send_op, /*index=*/{0}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(payload)));
}
1211 
TEST_P(HloDataflowAnalysisTest, SetDimensionSizeForwardsValue) {
  // A SetDimensionSize should not define a value of its own; the value of its
  // data operand is the one that ends up live out of the module.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* data = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* dynamic_size = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(3)));
  HloInstruction* set_dim_size =
      entry_builder.AddInstruction(HloInstruction::CreateSetDimensionSize(
          vector_shape_, data, dynamic_size, 0));

  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  // Only the parameter and the size constant define values.
  EXPECT_EQ(analysis.values().size(), 2);

  EXPECT_TRUE(analysis.ValueIsDefinedAt(data));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(set_dim_size));
  EXPECT_TRUE(analysis.GetValueDefinedAt(data).live_out_of_module());
}
1234 
TEST_P(HloDataflowAnalysisTest, RecvAndRecvDone) {
  // A RecvDone should forward element {0} of its Recv operand into element
  // {0} of its own output.
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* after_token =
      entry_builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* recv_op = entry_builder.AddInstruction(
      HloInstruction::CreateRecv(scalar_shape_, after_token, /*channel_id=*/0));
  HloInstruction* recv_done_op =
      entry_builder.AddInstruction(HloInstruction::CreateRecvDone(recv_op));
  module_->AddEntryComputation(entry_builder.Build());
  SCOPED_TRACE(module_->ToString());

  const bool ssa_form = GetParam();
  const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);

  EXPECT_EQ(analysis.values().size(), 7);

  // Recv defines a value at its top-level tuple and at elements {0}, {1} and
  // {2}.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_op, /*index=*/{}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_op, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_op, /*index=*/{1}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_op, /*index=*/{2}));

  // RecvDone defines its top-level tuple and element {1}; element {0} is the
  // value from Recv's element {0}, which is live out of the module.
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done_op, /*index=*/{}));
  EXPECT_FALSE(analysis.ValueIsDefinedAt(recv_done_op, /*index=*/{0}));
  EXPECT_TRUE(analysis.ValueIsDefinedAt(recv_done_op, /*index=*/{1}));
  EXPECT_THAT(HloValuesAt(recv_done_op, /*index=*/{0}),
              UnorderedElementsAre(&analysis.GetValueDefinedAt(recv_op, {0})));
  EXPECT_TRUE(
      analysis.GetValueDefinedAt(recv_op, /*index=*/{0}).live_out_of_module());
}
1263 
TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) {
  // A straight-line chain of elementwise ops under a dependency ordering:
  //
  //   param --> negate --> exp --> log
  //
  // No two distinct values in the chain are expected to interfere.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, negate));
  HloInstruction* log = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog, exp));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // Shorthand for querying interference under `ordering`.
  const auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(ordering, a, b);
  };

  // Distinct values along the chain never interfere.
  EXPECT_FALSE(interferes(param, negate));
  EXPECT_FALSE(interferes(param, exp));
  EXPECT_FALSE(interferes(param, log));
  EXPECT_FALSE(interferes(negate, exp));
  EXPECT_FALSE(interferes(negate, log));
  EXPECT_FALSE(interferes(exp, negate));
  EXPECT_FALSE(interferes(exp, log));
  EXPECT_FALSE(interferes(log, negate));
  EXPECT_FALSE(interferes(log, exp));

  // A value is expected to interfere with itself.
  EXPECT_TRUE(interferes(exp, exp));
}
1299 
TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) {
  // Two entry parameters feeding one add, checked under a sequential schedule
  // in which param1 is scheduled after negate:
  //
  //   param0 --> negate --+
  //                       +--> add
  //   param1 --> exp -----+
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, vector_shape_, "param1"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param0));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param1));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate, exp));

  HloComputation* entry = module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry, {param0, negate, param1, exp, add});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering ordering(schedule);

  // Shorthand for querying interference under `ordering`.
  const auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(ordering, a, b);
  };

  // Entry parameters are treated as if they were all defined simultaneously
  // at the very beginning, so they interfere with each other regardless of
  // where the schedule places them.
  EXPECT_TRUE(interferes(param0, param1));
  EXPECT_FALSE(interferes(param0, negate));
  EXPECT_FALSE(interferes(param0, exp));
  EXPECT_FALSE(interferes(param0, add));
  EXPECT_TRUE(interferes(param1, param0));
  EXPECT_TRUE(interferes(param1, negate));
  EXPECT_FALSE(interferes(param1, exp));
  EXPECT_FALSE(interferes(param1, add));

  // negate and exp still interfere with each other.
  EXPECT_TRUE(interferes(negate, exp));
  EXPECT_TRUE(interferes(exp, negate));

  // Neither {negate, add} nor {exp, add} interfere.
  EXPECT_FALSE(interferes(negate, add));
  EXPECT_FALSE(interferes(add, negate));
  EXPECT_FALSE(interferes(exp, add));
  EXPECT_FALSE(interferes(add, exp));
}
1347 
TEST_P(HloDataflowAnalysisTest,WhileParameters_Sequential)1348 TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) {
1349   // Similar to MultipleEntryParameters_Sequential, but the parameter is of
1350   // while body computation. Body computation in the sequential order:
1351   //
1352   //  %constant = Constant(...)
1353   //  %exp = Exp(%constant)
1354   //  %param = Param(0)
1355   //  %add = Add(%param, %exp)  ;; Root of body
1356   //  %dead_constant = Constant(...)
1357   //  %dead_negate = Negate(%dead_constant)
1358   //
1359   // %constant and its only use %exp are ordered before 'param'. However, the
1360   // %constant and %param values still interfere because the parameter is
1361   // considered live into the while body.
1362   //
1363   // Similarly, %dead_constant and %dead_negate are ordered after the root of
1364   // the body computation %add. However, %add is liveout of the computation so
1365   // %dead_constant and %add interfere.
1366   auto body_builder = HloComputation::Builder(TestName());
1367   auto body_param = body_builder.AddInstruction(
1368       HloInstruction::CreateParameter(0, scalar_shape_, "body_param"));
1369   auto constant = body_builder.AddInstruction(
1370       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1371   auto exp = body_builder.AddInstruction(
1372       HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
1373   auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
1374       scalar_shape_, HloOpcode::kAdd, exp, body_param));
1375   auto dead_constant = body_builder.AddInstruction(
1376       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
1377   auto dead_negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
1378       scalar_shape_, HloOpcode::kNegate, dead_constant));
1379   HloComputation* body = module_->AddEmbeddedComputation(
1380       body_builder.Build(/*root_instruction=*/add));
1381 
1382   auto cond_builder = HloComputation::Builder("condition");
1383   auto cond_param = cond_builder.AddInstruction(
1384       HloInstruction::CreateParameter(0, scalar_shape_, "cond_param"));
1385   auto cond_constant = cond_builder.AddInstruction(
1386       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1387   HloComputation* condition =
1388       module_->AddEmbeddedComputation(cond_builder.Build());
1389 
1390   auto builder = HloComputation::Builder(TestName());
1391   auto param = builder.AddInstruction(
1392       HloInstruction::CreateParameter(0, scalar_shape_, "param"));
1393   auto xla_while = builder.AddInstruction(
1394       HloInstruction::CreateWhile(scalar_shape_, condition, body, param));
1395 
1396   auto entry = module_->AddEntryComputation(builder.Build());
1397   SCOPED_TRACE(module_->ToString());
1398   bool ssa_form = GetParam();
1399   RunAnalysis(ssa_form, /*bitcast_defines_value=*/false,
1400               /*run_dce=*/false);
1401 
1402   HloSchedule schedule(module_.get());
1403   schedule.set_sequence(entry, {param, xla_while});
1404   schedule.set_sequence(condition, {cond_param, cond_constant});
1405   // Construct the order such that 'constant' and its use 'exp' are before
1406   // body_param.
1407   schedule.set_sequence(
1408       body, {constant, exp, body_param, add, dead_constant, dead_negate});
1409   TF_ASSERT_OK(schedule.Verify());
1410 
1411   SequentialHloOrdering ordering(schedule);
1412 
1413   // 'add' is live out of the body and will interfere with an later instructions
1414   // such as 'dead_constant' and 'dead_negate'.
1415   EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant));
1416   EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_negate));
1417 
1418   // The remaining checks test phi values defined by body and condition
1419   // parameters which only occur in the SSA form of the analysis.
1420   if (ssa_form) {
1421     // Though the ordering suggests 'constant' and 'param' should not interfere,
1422     // 'param' is live in and thus interferes with any earlier instruction of
1423     // the computation in the order (eg 'constant')'
1424     EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, constant));
1425     EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, exp));
1426     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add));
1427 
1428     // The following values end up in the same buffer:
1429     //  (1) the init value: 'param'
1430     //  (2) the body parameter: 'body_param'
1431     //  (3) the condition parameter: 'cond_param'
1432     //  (4) the root value of the while body: 'add'
1433     //  (5) the while value: 'xla_while'
1434     // None should interfere.
1435     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, body_param));
1436     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, cond_param));
1437     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add));
1438     EXPECT_FALSE(InstructionsMayInterfere(ordering, param, xla_while));
1439 
1440     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, cond_param));
1441     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add));
1442     EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, xla_while));
1443 
1444     EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, add));
1445     EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, xla_while));
1446 
1447     EXPECT_FALSE(InstructionsMayInterfere(ordering, add, xla_while));
1448   }
1449 }
1450 
TEST_P(HloDataflowAnalysisTest, NonElementwiseOperand) {
  // A chain with two elementwise ops followed by one non-elementwise op:
  //
  //   param --> exp --> negate --> reverse
  //
  // The elementwise negate is expected not to interfere with its operand,
  // whereas the non-elementwise reverse is expected to interfere with its
  // operand. The parameter is expected not to interfere with any later value.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, exp));
  HloInstruction* reverse = builder.AddInstruction(
      HloInstruction::CreateReverse(vector_shape_, negate, {0}));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // Shorthand for querying interference under `ordering`.
  const auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(ordering, a, b);
  };

  EXPECT_FALSE(interferes(param, exp));
  EXPECT_FALSE(interferes(param, negate));
  EXPECT_FALSE(interferes(param, reverse));

  // Elementwise negate vs. its operand: no interference.
  // Non-elementwise reverse vs. its operand: interference.
  EXPECT_FALSE(interferes(exp, negate));
  EXPECT_TRUE(interferes(negate, reverse));
}
1483 
TEST_P(HloDataflowAnalysisTest, OverlappedValues) {
  // Values that are live at the same time (negate and exp) must interfere.
  //
  //   param --+--> negate --+
  //           |             +--> add
  //           +--> exp -----+
  //
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate, exp));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // Shorthand for querying interference under `ordering`.
  const auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(ordering, a, b);
  };

  EXPECT_TRUE(interferes(param, negate));
  EXPECT_TRUE(interferes(param, exp));
  EXPECT_FALSE(interferes(param, add));

  // negate and exp interfere with each other, but neither interferes with
  // add.
  EXPECT_TRUE(interferes(negate, exp));
  EXPECT_TRUE(interferes(exp, negate));
  EXPECT_FALSE(interferes(negate, add));
  EXPECT_FALSE(interferes(add, negate));
  EXPECT_FALSE(interferes(exp, add));
  EXPECT_FALSE(interferes(add, exp));
}
1518 
TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) {
  // Same graph as the OverlappedValues test, but checked under an explicit
  // sequential schedule instead of a dependency ordering.
  //
  //   param --+--> negate --+
  //           |             +--> add
  //           +--> exp -----+
  //
  // Sequential order:
  //   param, negate, exp, add
  //
  // Under this schedule (param, exp) is expected not to interfere.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate, exp));

  HloComputation* entry = module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  HloSchedule schedule(module_.get());
  schedule.set_sequence(entry, {param, negate, exp, add});
  TF_ASSERT_OK(schedule.Verify());
  SequentialHloOrdering ordering(schedule);

  // Shorthand for querying interference under `ordering`.
  const auto interferes = [&](HloInstruction* a, HloInstruction* b) {
    return InstructionsMayInterfere(ordering, a, b);
  };

  EXPECT_TRUE(interferes(param, negate));
  EXPECT_FALSE(interferes(param, exp));
  EXPECT_FALSE(interferes(param, add));

  // negate and exp interfere with each other, but neither interferes with
  // add.
  EXPECT_TRUE(interferes(negate, exp));
  EXPECT_TRUE(interferes(exp, negate));
  EXPECT_FALSE(interferes(negate, add));
  EXPECT_FALSE(interferes(add, negate));
  EXPECT_FALSE(interferes(exp, add));
  EXPECT_FALSE(interferes(add, exp));
}
1561 
TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
  // Exercises interference between values that live in different
  // computations.
  //
  // embedded_computation:
  //   %embedded_param = Param(0)
  //   %embedded_log = Log(%embedded_param)
  //
  // entry computation:
  //   %param = Param(0)
  //   %negate = Negate(%param)
  //   %exp = Exp(%param)
  //   %call = Call(embedded_computation, {%exp})
  //   %add = Add(%negate, %call)
  //
  // %negate is consumed by %add, i.e. after the call, so it is live across
  // the call and should interfere with values inside the embedded
  // computation.
  HloComputation::Builder embedded_builder(TestName() + "_embedded");
  HloInstruction* embedded_param = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "embedded_param"));
  HloInstruction* embedded_log =
      embedded_builder.AddInstruction(HloInstruction::CreateUnary(
          vector_shape_, HloOpcode::kLog, embedded_param));
  HloComputation* embedded_computation =
      module_->AddEmbeddedComputation(embedded_builder.Build());

  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape_, "param"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(vector_shape_, {exp}, embedded_computation));
  builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape_, HloOpcode::kAdd, negate, call));
  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());
  RunAnalysis(GetParam());

  DependencyHloOrdering ordering(module_.get());

  // The call is the only use of exp, so exp should not interfere with values
  // inside the embedded computation.
  EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, embedded_log));

  // negate is live across the call, so it should interfere with values inside
  // the embedded computation.
  EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log));
}
1613 
TEST_P(HloDataflowAnalysisTest, GetFlattenedValueSet) {
  // Parses a module whose root tuple holds two copies of the same parameter
  // and checks that the flattened value set of the tuple contains three
  // values.
  const char* hlo_text = R"(
HloModule test_aliasing_module

ENTRY root {
  param = s32[1000] parameter(0)
  p0 = s32[1000] copy(param)
  p1 = s32[1000] copy(param)
  ROOT t = (s32[1000], s32[1000]) tuple(p0, p1)
  })";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_text));
  auto entry = module_->entry_computation();
  // Look the root tuple up once and stop the test if it is missing, rather
  // than handing a null instruction to the analysis.
  auto tuple = entry->GetInstructionWithName("t");
  ASSERT_NE(tuple, nullptr);
  auto& dataflow_analysis = RunAnalysis(GetParam());
  auto set = dataflow_analysis.GetFlattenedValueSet(tuple);
  EXPECT_EQ(set.values().size(), 3);
}
1632 
TEST_P(HloDataflowAnalysisTest,ConditionalWithIdentity)1633 TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) {
1634   // Test conditional with identity computations in both true and false cases.
1635   //
1636   // true_computation(F32[] %true_param):
1637   //   return %true_param
1638   //
1639   // false_computation(F32[] %false_param):
1640   //   return %false_param
1641   //
1642   // entry:
1643   //   %pred = Constant(true)
1644   //   %constant1 = Constant(56.0)
1645   //   %constant2 = Constant(12.0)
1646   //   return Conditional(%pred, %constant1, true_computation,
1647   //                      %constant2, false_computation)
1648 
1649   auto true_builder = HloComputation::Builder(TestName() + "_true");
1650   auto true_param = true_builder.AddInstruction(
1651       HloInstruction::CreateParameter(0, scalar_shape_, "true_param"));
1652   HloComputation* true_computation =
1653       module_->AddEmbeddedComputation(true_builder.Build());
1654 
1655   auto false_builder = HloComputation::Builder(TestName() + "_false");
1656   auto false_param = false_builder.AddInstruction(
1657       HloInstruction::CreateParameter(0, scalar_shape_, "false_param"));
1658   HloComputation* false_computation =
1659       module_->AddEmbeddedComputation(false_builder.Build());
1660 
1661   auto builder = HloComputation::Builder(TestName());
1662   auto pred = builder.AddInstruction(
1663       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
1664   auto constant1 = builder.AddInstruction(
1665       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
1666   auto constant2 = builder.AddInstruction(
1667       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
1668   auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
1669       scalar_shape_, pred, constant1, true_computation, constant2,
1670       false_computation));
1671   module_->AddEntryComputation(builder.Build());
1672   SCOPED_TRACE(module_->ToString());
1673 
1674   const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
1675 
1676   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
1677   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
1678   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
1679 
1680   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
1681   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));
1682 
1683   EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
1684             analysis.GetValueDefinedAt(constant1));
1685   EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
1686             analysis.GetValueDefinedAt(constant2));
1687 
1688   EXPECT_THAT(analysis.GetValueDefinedAt(pred).GetUses(),
1689               ElementsAre(HloUse{conditional, 0, {}}));
1690   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).GetUses(),
1691               ElementsAre(HloUse{conditional, 1, {}}));
1692   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).GetUses(),
1693               ElementsAre(HloUse{conditional, 2, {}}));
1694 
1695   bool ssa_form = GetParam();
1696   if (ssa_form) {
1697     EXPECT_EQ(analysis.values().size(), 4);
1698     EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
1699   } else {
1700     EXPECT_EQ(analysis.values().size(), 3);
1701     EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
1702     EXPECT_THAT(HloValuesAt(conditional),
1703                 UnorderedElementsAre(&analysis.GetValueDefinedAt(constant1),
1704                                      &analysis.GetValueDefinedAt(constant2)));
1705   }
1706 }
1707 
TEST_P(HloDataflowAnalysisTest,ConditionalTakingTupleOperand)1708 TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) {
1709   // Test conditional with true and false computations taking a tuple operand.
1710   //
1711   // true_computation((F32[], F32[]) %true_param):
1712   //   %true_x = GetTupleElement(%true_param, 0)
1713   //   %true_y = GetTupleElement(%true_param, 1)
1714   //   return Add(%true_x, %true_y)
1715   //
1716   // false_computation((F32[], F32[]) %false_param):
1717   //   %false_x = GetTupleElement(%false_param, 0)
1718   //   %false_y = GetTupleElement(%false_param, 1)
1719   //   return Subtract(%false_x, %false_y)
1720   //
1721   // entry:
1722   //   %pred = Constant(true)
1723   //   %constant1 = Constant(56.0)
1724   //   %constant2 = Constant(12.0)
1725   //   %tuple_operand = Tuple(%constant1, %constant2)
1726   //   return Conditional(%pred, %tuple_operand, true_computation,
1727   //                      %tuple_operand, false_computation)
1728 
1729   auto true_builder = HloComputation::Builder(TestName() + "_true");
1730   auto true_param = true_builder.AddInstruction(
1731       HloInstruction::CreateParameter(0, tuple_shape_, "true_param"));
1732   auto true_x = true_builder.AddInstruction(
1733       HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0));
1734   auto true_y = true_builder.AddInstruction(
1735       HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1));
1736   auto add = true_builder.AddInstruction(HloInstruction::CreateBinary(
1737       scalar_shape_, HloOpcode::kAdd, true_x, true_y));
1738   HloComputation* true_computation =
1739       module_->AddEmbeddedComputation(true_builder.Build());
1740 
1741   auto false_builder = HloComputation::Builder(TestName() + "_false");
1742   auto false_param = false_builder.AddInstruction(
1743       HloInstruction::CreateParameter(0, tuple_shape_, "false_param"));
1744   auto false_x = false_builder.AddInstruction(
1745       HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0));
1746   auto false_y = false_builder.AddInstruction(
1747       HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1));
1748   auto sub = false_builder.AddInstruction(HloInstruction::CreateBinary(
1749       scalar_shape_, HloOpcode::kSubtract, false_x, false_y));
1750   HloComputation* false_computation =
1751       module_->AddEmbeddedComputation(false_builder.Build());
1752 
1753   auto builder = HloComputation::Builder(TestName());
1754   auto pred = builder.AddInstruction(
1755       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
1756   auto constant1 = builder.AddInstruction(
1757       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(56.0f)));
1758   auto constant2 = builder.AddInstruction(
1759       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(12.0f)));
1760   auto tuple_operand = builder.AddInstruction(
1761       HloInstruction::CreateTuple({constant1, constant2}));
1762   auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
1763       scalar_shape_, pred, tuple_operand, true_computation, tuple_operand,
1764       false_computation));
1765   module_->AddEntryComputation(builder.Build());
1766   SCOPED_TRACE(module_->ToString());
1767 
1768   const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
1769 
1770   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred));
1771   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
1772   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
1773   EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
1774   EXPECT_TRUE(analysis.ValueIsDefinedAt(add));
1775   EXPECT_TRUE(analysis.ValueIsDefinedAt(sub));
1776 
1777   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param));
1778   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param));
1779   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x));
1780   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y));
1781   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x));
1782   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y));
1783 
1784   EXPECT_EQ(analysis.GetUniqueValueAt(true_param),
1785             analysis.GetValueDefinedAt(tuple_operand));
1786   EXPECT_EQ(analysis.GetUniqueValueAt(false_param),
1787             analysis.GetValueDefinedAt(tuple_operand));
1788   EXPECT_EQ(analysis.GetUniqueValueAt(true_x),
1789             analysis.GetValueDefinedAt(constant1));
1790   EXPECT_EQ(analysis.GetUniqueValueAt(true_y),
1791             analysis.GetValueDefinedAt(constant2));
1792   EXPECT_EQ(analysis.GetUniqueValueAt(false_x),
1793             analysis.GetValueDefinedAt(constant1));
1794   EXPECT_EQ(analysis.GetUniqueValueAt(false_y),
1795             analysis.GetValueDefinedAt(constant2));
1796 
1797   EXPECT_THAT(analysis.GetValueDefinedAt(pred).GetUses(),
1798               ElementsAre(HloUse{conditional, 0, {}}));
1799   EXPECT_THAT(analysis.GetValueDefinedAt(constant1).GetUses(),
1800               UnorderedElementsAre(HloUse{conditional, 1, {0}},
1801                                    HloUse{conditional, 2, {0}},
1802                                    HloUse{add, 0, {}}, HloUse{sub, 0, {}}));
1803   EXPECT_THAT(analysis.GetValueDefinedAt(constant2).GetUses(),
1804               UnorderedElementsAre(HloUse{conditional, 1, {1}},
1805                                    HloUse{conditional, 2, {1}},
1806                                    HloUse{add, 1, {}}, HloUse{sub, 1, {}}));
1807   EXPECT_THAT(analysis.GetValueDefinedAt(tuple_operand).GetUses(),
1808               UnorderedElementsAre(
1809                   HloUse{conditional, 1, {}}, HloUse{conditional, 2, {}},
1810                   HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}},
1811                   HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}}));
1812 
1813   bool ssa_form = GetParam();
1814   if (ssa_form) {
1815     EXPECT_EQ(analysis.values().size(), 7);
1816     EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
1817   } else {
1818     EXPECT_EQ(analysis.values().size(), 6);
1819     EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
1820     EXPECT_THAT(HloValuesAt(conditional),
1821                 UnorderedElementsAre(&analysis.GetValueDefinedAt(add),
1822                                      &analysis.GetValueDefinedAt(sub)));
1823   }
1824 }
1825 
TEST_P(HloDataflowAnalysisTest,NestedConditionals)1826 TEST_P(HloDataflowAnalysisTest, NestedConditionals) {
1827   // computation1(F32[] %param1):
1828   //   %ceil = Ceil(%param1)
1829   //   return %ceil
1830   //
1831   // computation2(F32[] %param2):
1832   //   %floor = Floor(%param2)
1833   //   return %floor
1834   //
1835   // computation3(F32[] %param3):
1836   //   %negate = Negate(%param3)
1837   //   return %negate
1838   //
1839   // inner_conditional((PRED, F32[], F32[]) %param_cond):
1840   //   %pred_cond = GetTupleElement(%param_cond, 0)
1841   //   %true_operand_cond = GetTupleElement(%param_cond, 1)
1842   //   %false_operand_cond = GetTupleElement(%param_cond, 2)
1843   //   return Conditional(%pred_cond, %true_operand_cond, computation1,
1844   //                      %false_operand_cond, computation2)
1845   //
1846   // entry:
1847   //   %pred1 = Constant(true)
1848   //   %pred2 = Constant(false)
1849   //   %constant1 = Constant(1.1);
1850   //   %constant2 = Constant(2.2);
1851   //   %constant3 = Constant(3.3);
1852   //   return Conditional(%pred1, (%pred2, %constant1, %constant2),
1853   //                      inner_conditional, %constant3, computation3)
1854 
1855   auto computation1 = module_->AddEmbeddedComputation(
1856       CreateR0F32UnaryOpComputation(HloOpcode::kCeil));
1857   auto computation2 = module_->AddEmbeddedComputation(
1858       CreateR0F32UnaryOpComputation(HloOpcode::kFloor));
1859   auto computation3 = module_->AddEmbeddedComputation(
1860       CreateR0F32UnaryOpComputation(HloOpcode::kNegate));
1861 
1862   // Build inner_conditional computation.
1863   const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {});
1864   const Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
1865       {scalar_bool_shape, scalar_shape_, scalar_shape_});
1866   auto inner_builder =
1867       HloComputation::Builder(TestName() + "_inner_conditional");
1868   auto param_cond = inner_builder.AddInstruction(
1869       HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond"));
1870   auto pred_cond = inner_builder.AddInstruction(
1871       HloInstruction::CreateGetTupleElement(scalar_bool_shape, param_cond, 0));
1872   auto true_operand_cond = inner_builder.AddInstruction(
1873       HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 1));
1874   auto false_operand_cond = inner_builder.AddInstruction(
1875       HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 2));
1876   auto inner_conditional =
1877       inner_builder.AddInstruction(HloInstruction::CreateConditional(
1878           scalar_shape_, pred_cond, true_operand_cond, computation1,
1879           false_operand_cond, computation2));
1880   auto inner_conditional_computation =
1881       module_->AddEmbeddedComputation(inner_builder.Build());
1882 
1883   // Build entry computation.
1884   auto builder = HloComputation::Builder(TestName());
1885   auto pred1 = builder.AddInstruction(
1886       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
1887   auto pred2 = builder.AddInstruction(
1888       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
1889   auto constant1 = builder.AddInstruction(
1890       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
1891   auto constant2 = builder.AddInstruction(
1892       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.2f)));
1893   auto constant3 = builder.AddInstruction(
1894       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.3f)));
1895   auto tuple_operand = builder.AddInstruction(
1896       HloInstruction::CreateTuple({pred2, constant1, constant2}));
1897   auto conditional = builder.AddInstruction(HloInstruction::CreateConditional(
1898       scalar_shape_, pred1, tuple_operand, inner_conditional_computation,
1899       constant3, computation3));
1900   module_->AddEntryComputation(builder.Build());
1901   SCOPED_TRACE(module_->ToString());
1902 
1903   const HloDataflowAnalysis& analysis = RunAnalysis(GetParam());
1904 
1905   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred1));
1906   EXPECT_TRUE(analysis.ValueIsDefinedAt(pred2));
1907   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1));
1908   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2));
1909   EXPECT_TRUE(analysis.ValueIsDefinedAt(constant3));
1910   EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand));
1911   EXPECT_TRUE(analysis.ValueIsDefinedAt(computation1->root_instruction()));
1912   EXPECT_TRUE(analysis.ValueIsDefinedAt(computation2->root_instruction()));
1913   EXPECT_TRUE(analysis.ValueIsDefinedAt(computation3->root_instruction()));
1914 
1915   auto computation1_param = computation1->parameter_instruction(0);
1916   auto computation2_param = computation2->parameter_instruction(0);
1917   auto computation3_param = computation3->parameter_instruction(0);
1918   EXPECT_FALSE(analysis.ValueIsDefinedAt(computation1_param));
1919   EXPECT_FALSE(analysis.ValueIsDefinedAt(computation2_param));
1920   EXPECT_FALSE(analysis.ValueIsDefinedAt(computation3_param));
1921   EXPECT_EQ(analysis.GetUniqueValueAt(computation1_param),
1922             analysis.GetValueDefinedAt(constant1));
1923   EXPECT_EQ(analysis.GetUniqueValueAt(computation2_param),
1924             analysis.GetValueDefinedAt(constant2));
1925   EXPECT_EQ(analysis.GetUniqueValueAt(computation3_param),
1926             analysis.GetValueDefinedAt(constant3));
1927 
1928   EXPECT_FALSE(analysis.ValueIsDefinedAt(param_cond));
1929   EXPECT_FALSE(analysis.ValueIsDefinedAt(pred_cond));
1930   EXPECT_FALSE(analysis.ValueIsDefinedAt(true_operand_cond));
1931   EXPECT_FALSE(analysis.ValueIsDefinedAt(false_operand_cond));
1932   EXPECT_EQ(analysis.GetUniqueValueAt(param_cond),
1933             analysis.GetValueDefinedAt(tuple_operand));
1934   EXPECT_EQ(analysis.GetUniqueValueAt(pred_cond),
1935             analysis.GetValueDefinedAt(pred2));
1936   EXPECT_EQ(analysis.GetUniqueValueAt(true_operand_cond),
1937             analysis.GetValueDefinedAt(constant1));
1938   EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond),
1939             analysis.GetValueDefinedAt(constant2));
1940 
1941   bool ssa_form = GetParam();
1942   if (ssa_form) {
1943     EXPECT_EQ(analysis.values().size(), 11);
1944     EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_conditional));
1945     EXPECT_TRUE(analysis.ValueIsDefinedAt(conditional));
1946   } else {
1947     EXPECT_EQ(analysis.values().size(), 9);
1948     EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional));
1949     EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional));
1950     EXPECT_THAT(
1951         HloValuesAt(inner_conditional),
1952         UnorderedElementsAre(
1953             &analysis.GetValueDefinedAt(computation1->root_instruction()),
1954             &analysis.GetValueDefinedAt(computation2->root_instruction())));
1955     EXPECT_THAT(
1956         HloValuesAt(conditional),
1957         UnorderedElementsAre(
1958             &analysis.GetValueDefinedAt(computation1->root_instruction()),
1959             &analysis.GetValueDefinedAt(computation2->root_instruction()),
1960             &analysis.GetValueDefinedAt(computation3->root_instruction())));
1961   }
1962 }
1963 
// Checks which instructions define values in a module whose root is an
// add-dependency. The test parameter is not consulted here: the analysis is
// run with the default arguments of HloDataflowAnalysis::Run.
TEST_P(HloDataflowAnalysisTest, AddDependency) {
  const std::string hlo_text = R"(
HloModule AddDependency
ENTRY %AddDependency (p: f32[3]) -> f32[3] {
  %p = f32[3] parameter(0)
  %token0 = token[] after-all()
  ROOT %add_dep = f32[3] add-dependency(f32[3] %p, token[] %token0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> parsed_module,
      ParseAndReturnVerifiedModule(hlo_text, GetModuleConfigForTest()));

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> dataflow,
                          HloDataflowAnalysis::Run(*parsed_module));
  const HloInstruction* add_dep =
      parsed_module->entry_computation()->root_instruction();
  EXPECT_EQ(add_dep->opcode(), HloOpcode::kAddDependency);

  // Exactly two values are expected (the parameter and the after-all token);
  // the add-dependency root must not define one of its own.
  EXPECT_EQ(dataflow->values().size(), 2);
  EXPECT_FALSE(dataflow->ValueIsDefinedAt(add_dep));
}
1987 
// Verifies value definitions and uses across an all-reduce-start /
// all-reduce-done pair with a single operand: the start instruction defines a
// value, the done instruction does not, and each instruction is the only use
// of the value feeding it.
TEST_F(HloDataflowAnalysisTest, AllReduceStartAndDone) {
  const char* const kHloText = R"(
    HloModule test
    add {
      x = f32[] parameter(0)
      y = f32[] parameter(1)
      ROOT add = f32[] add(x, y)
    }
    ENTRY entry {
      p0 = f32[2] parameter(0)
      start = f32[2] all-reduce-start(p0), to_apply=add
      ROOT done = f32[2] all-reduce-done(start)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> dataflow,
                          HloDataflowAnalysis::Run(*hlo_module));

  // Walk backwards from the root: done <- start <- p0.
  HloInstruction* all_reduce_done =
      hlo_module->entry_computation()->root_instruction();
  HloInstruction* all_reduce_start = all_reduce_done->mutable_operand(0);
  HloInstruction* input = all_reduce_start->mutable_operand(0);

  EXPECT_TRUE(dataflow->ValueIsDefinedAt(all_reduce_start, /*index=*/{}));
  EXPECT_FALSE(dataflow->ValueIsDefinedAt(all_reduce_done));

  const HloUse use_by_start{all_reduce_start, 0, {}};
  const HloUse use_by_done{all_reduce_done, 0, {}};
  EXPECT_THAT(dataflow->GetValueDefinedAt(input).GetUses(),
              UnorderedElementsAre(use_by_start));
  EXPECT_THAT(dataflow->GetValueDefinedAt(all_reduce_start).GetUses(),
              UnorderedElementsAre(use_by_done));
}
2020 
// Two-operand variant of the all-reduce-start / all-reduce-done check: the
// start instruction is expected to define values at its top-level index and
// at both tuple elements, while the done instruction defines none.
TEST_F(HloDataflowAnalysisTest, AllReduceStartAndDoneTwoOperands) {
  const char* const kHloText = R"(
    HloModule test
    add {
      x = f32[] parameter(0)
      y = f32[] parameter(1)
      ROOT add = f32[] add(x, y)
    }
    ENTRY entry {
      p0 = f32[2] parameter(0)
      p1 = f32[2] parameter(1)
      start = (f32[2], f32[2]) all-reduce-start(p0, p1), to_apply=add
      ROOT done = (f32[2], f32[2]) all-reduce-done(start)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloDataflowAnalysis> dataflow,
                          HloDataflowAnalysis::Run(*hlo_module));

  // Walk backwards from the root: done <- start <- {p0, p1}.
  HloInstruction* all_reduce_done =
      hlo_module->entry_computation()->root_instruction();
  HloInstruction* all_reduce_start = all_reduce_done->mutable_operand(0);
  HloInstruction* first_input = all_reduce_start->mutable_operand(0);
  HloInstruction* second_input = all_reduce_start->mutable_operand(1);

  EXPECT_TRUE(dataflow->ValueIsDefinedAt(all_reduce_start, /*index=*/{}));
  EXPECT_TRUE(dataflow->ValueIsDefinedAt(all_reduce_start, /*index=*/{0}));
  EXPECT_TRUE(dataflow->ValueIsDefinedAt(all_reduce_start, /*index=*/{1}));
  EXPECT_FALSE(dataflow->ValueIsDefinedAt(all_reduce_done));

  const HloUse start_uses_first{all_reduce_start, 0, {}};
  const HloUse start_uses_second{all_reduce_start, 1, {}};
  const HloUse done_uses_start{all_reduce_done, 0, {}};
  EXPECT_THAT(dataflow->GetValueDefinedAt(first_input).GetUses(),
              UnorderedElementsAre(start_uses_first));
  EXPECT_THAT(dataflow->GetValueDefinedAt(second_input).GetUses(),
              UnorderedElementsAre(start_uses_second));
  EXPECT_THAT(dataflow->GetValueDefinedAt(all_reduce_start, {}).GetUses(),
              UnorderedElementsAre(done_uses_start));
}
2059 
2060 INSTANTIATE_TEST_SUITE_P(HloDataflowAnalysisInstantiation,
2061                          HloDataflowAnalysisTest,
2062                          ::testing::Values(false, true));
2063 
RunAnalysis(const HloModule & module,const HloDataflowAnalysis::CanShareBuffer & can_share_buffer=nullptr)2064 std::unique_ptr<HloDataflowAnalysis> RunAnalysis(
2065     const HloModule& module,
2066     const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) {
2067   return HloDataflowAnalysis::Run(module, /*ssa_form=*/false,
2068                                   /*bitcast_defines_value=*/false,
2069                                   can_share_buffer)
2070       .value();
2071 }
2072 
2073 using DoesNotUseOperandBufferTest = HloTestBase;
2074 
// Builds add(gte(tuple, 0), gte(tuple, 1)) over a two-element tuple parameter
// and checks which buffers of the tuple each get-tuple-element is reported to
// use.
TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) {
  HloComputation::Builder b(TestName());

  const Shape element_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pair_shape =
      ShapeUtil::MakeTupleShape({element_shape, element_shape});
  HloInstruction* tuple_param = b.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* first = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(element_shape, tuple_param, 0));
  HloInstruction* second = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(element_shape, tuple_param, 1));
  b.AddInstruction(HloInstruction::CreateBinary(
      element_shape, HloOpcode::kAdd, first, second));

  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(b.Build());
  std::unique_ptr<HloDataflowAnalysis> dataflow = RunAnalysis(*hlo_module);

  // A GetTupleElement touches only the top-level buffer of its operand, never
  // the element buffers themselves.
  EXPECT_TRUE(dataflow->DoesNotUseOperandBuffer(tuple_param, {0}, first));
  EXPECT_TRUE(dataflow->DoesNotUseOperandBuffer(tuple_param, {1}, second));
  EXPECT_FALSE(dataflow->DoesNotUseOperandBuffer(tuple_param, {}, first));
  EXPECT_FALSE(dataflow->DoesNotUseOperandBuffer(tuple_param, {}, second));
}
2099 
// Fuses a dynamic-update-slice of tuple element 1 (together with its
// get-tuple-element, start index and update) into a loop fusion and checks
// which elements of the tuple parameter the fusion is reported to use.
TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) {
  HloComputation::Builder b(TestName());

  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pair_shape = ShapeUtil::MakeTupleShape({vec_shape, vec_shape});
  HloInstruction* tuple_param = b.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* elem0 = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, tuple_param, 0));
  HloInstruction* elem1 = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, tuple_param, 1));

  // DynamicUpdateSlice that overwrites part of tuple element 1.
  HloInstruction* start_index = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(2)));
  HloInstruction* new_values = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          vec_shape, elem1, new_values,
          std::initializer_list<HloInstruction*>({start_index})));
  b.AddInstruction(HloInstruction::CreateTuple({elem0, dus}));

  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());
  HloInstruction* fused = entry->CreateFusionInstruction(
      {dus, start_index, new_values, elem1},
      HloInstruction::FusionKind::kLoop);
  std::unique_ptr<HloDataflowAnalysis> dataflow = RunAnalysis(*hlo_module);

  // Element 0 is never touched by the fusion; element 1 is.
  EXPECT_TRUE(dataflow->DoesNotUseOperandBuffer(tuple_param, {0}, fused));
  EXPECT_FALSE(dataflow->DoesNotUseOperandBuffer(tuple_param, {1}, fused));
}
2134 
2135 // Similar to FusedDynamicUpdateSlice above, but tests indirect uses of the
2136 // parameter tuple.
TEST_F(DoesNotUseOperandBufferTest,IndirectUses)2137 TEST_F(DoesNotUseOperandBufferTest, IndirectUses) {
2138   auto builder = HloComputation::Builder(TestName());
2139 
2140   Shape data_shape = ShapeUtil::MakeShape(F32, {8});
2141   auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
2142       0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "tuple"));
2143   auto t0 = builder.AddInstruction(
2144       HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));
2145   auto t1 = builder.AddInstruction(
2146       HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 1));
2147   // Swap the tuple elements.
2148   auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({t1, t0}));
2149 
2150   auto gte0 = builder.AddInstruction(
2151       HloInstruction::CreateGetTupleElement(data_shape, tuple, 0));
2152   auto gte1 = builder.AddInstruction(
2153       HloInstruction::CreateGetTupleElement(data_shape, tuple, 1));
2154 
2155   // Create a DynamicUpdateSlice instruction of tuple element 1.
2156   auto starts = builder.AddInstruction(
2157       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(2)));
2158   auto update = builder.AddInstruction(HloInstruction::CreateConstant(
2159       LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
2160   auto dynamic_update_slice =
2161       builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
2162           data_shape, gte1, update,
2163           std::initializer_list<HloInstruction*>({starts})));
2164   builder.AddInstruction(
2165       HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
2166 
2167   auto module = CreateNewVerifiedModule();
2168   auto computation = module->AddEntryComputation(builder.Build());
2169   auto fusion = computation->CreateFusionInstruction(
2170       {dynamic_update_slice, starts, update, gte1},
2171       HloInstruction::FusionKind::kLoop);
2172   auto dataflow_analysis = RunAnalysis(*module);
2173 
2174   // The fusion instruction never uses tuple element 0, but does use element 1.
2175   EXPECT_TRUE(dataflow_analysis->DoesNotUseOperandBuffer(tuple, {0}, fusion));
2176   EXPECT_FALSE(dataflow_analysis->DoesNotUseOperandBuffer(tuple, {1}, fusion));
2177   // The same holds for the parameter tuple, except that the tuple elements
2178   // are swapped in 'tuple'.
2179   EXPECT_TRUE(
2180       dataflow_analysis->DoesNotUseOperandBuffer(tuple_param, {1}, fusion));
2181   EXPECT_FALSE(
2182       dataflow_analysis->DoesNotUseOperandBuffer(tuple_param, {0}, fusion));
2183 }
2184 
2185 using CanShareOperandBufferWithUserTest = HloTestBase;
2186 
// Chain of same-shaped elementwise unary ops (param -> exp -> log): each op is
// expected to be able to share its operand's buffer.
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) {
  HloComputation::Builder b(TestName());

  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "param"));
  HloInstruction* exp_op = b.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, input));
  HloInstruction* log_op = b.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kLog, exp_op));

  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(b.Build());
  std::unique_ptr<HloDataflowAnalysis> dataflow = RunAnalysis(*hlo_module);

  EXPECT_TRUE(dataflow->CanShareOperandBufferWithUser(input, {}, exp_op, {}));
  EXPECT_TRUE(dataflow->CanShareOperandBufferWithUser(exp_op, {}, log_op, {}));
}
2207 
// A loop fusion of negate followed by reverse is expected NOT to be able to
// share the buffer of its parameter operand.
TEST_F(CanShareOperandBufferWithUserTest,
       NonElementwiseLoopFusionCantAliasOperandBuffer) {
  HloComputation::Builder b(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});

  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "param0"));

  HloInstruction* negated = b.AddInstruction(
      HloInstruction::CreateUnary(matrix_shape, HloOpcode::kNegate, input));

  HloInstruction* reversed = b.AddInstruction(
      HloInstruction::CreateReverse(matrix_shape, negated, {0, 1}));

  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());
  HloInstruction* fused = entry->CreateFusionInstruction(
      {reversed, negated}, HloInstruction::FusionKind::kLoop);
  std::unique_ptr<HloDataflowAnalysis> dataflow = RunAnalysis(*hlo_module);

  EXPECT_FALSE(dataflow->CanShareOperandBufferWithUser(input, {}, fused, {}));
}
2231 
// Fuses two copies and the tuple that collects them into one multi-output
// loop fusion, then checks that each parameter is reported as able to share
// its buffer with either output of the fusion.
TEST_F(CanShareOperandBufferWithUserTest,
       MultiOutputFusionCanAliasOperandBuffer) {
  auto builder = HloComputation::Builder(TestName());

  // Only one shape is needed: both parameters and both copies are f32[8].
  Shape in_shape = ShapeUtil::MakeShape(F32, {8});
  auto param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, in_shape, "param0"));
  auto param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, in_shape, "param1"));

  auto copy0 = builder.AddInstruction(
      HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param0));
  auto copy1 = builder.AddInstruction(
      HloInstruction::CreateUnary(in_shape, HloOpcode::kCopy, param1));

  // Note the order: output 0 is the copy of param1, output 1 the copy of
  // param0.
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({copy1, copy0}));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  auto fusion = computation->CreateFusionInstruction(
      {tuple, copy1, copy0}, HloInstruction::FusionKind::kLoop);
  auto dataflow_analysis = RunAnalysis(*module);

  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param0, {},
                                                               fusion, {0}));
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param0, {},
                                                               fusion, {1}));
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param1, {},
                                                               fusion, {0}));
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param1, {},
                                                               fusion, {1}));
}
2267 
// A loop fusion made only of elementwise unary ops (negate, exp) applied to a
// broadcast constant.
// NOTE(review): the test name says "CantAlias", yet the assertion below
// expects that the operand buffer CAN be shared -- confirm whether the name is
// stale.
TEST_F(CanShareOperandBufferWithUserTest,
       ElementwiseLoopFusionCantAliasOperandBuffer) {
  HloComputation::Builder b(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});

  HloInstruction* scalar_one = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* broadcast_one = b.AddInstruction(
      HloInstruction::CreateBroadcast(matrix_shape, scalar_one, {}));

  HloInstruction* negated = b.AddInstruction(HloInstruction::CreateUnary(
      matrix_shape, HloOpcode::kNegate, broadcast_one));

  HloInstruction* exponentiated = b.AddInstruction(
      HloInstruction::CreateUnary(matrix_shape, HloOpcode::kExp, negated));

  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());
  HloInstruction* fused = entry->CreateFusionInstruction(
      {exponentiated, negated}, HloInstruction::FusionKind::kLoop);
  std::unique_ptr<HloDataflowAnalysis> dataflow = RunAnalysis(*hlo_module);

  EXPECT_TRUE(
      dataflow->CanShareOperandBufferWithUser(broadcast_one, {}, fused, {}));
}
2293 
// Fusion in which a dynamic-update-slice writes back, at the same indices, a
// dynamic-slice taken from the very same parameter. The parameter's buffer is
// expected to be shareable with the fusion output.
TEST_F(CanShareOperandBufferWithUserTest,
       CanShareOperandWhenDynamicUpdateSliceIsFedByDynamicSliceWithSameIndex) {
  HloComputation::Builder b(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape row_shape = ShapeUtil::MakeShape(F32, {1, 2});

  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "param0"));
  HloInstruction* index_zero = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64_t>(0)));
  HloInstruction* slice = b.AddInstruction(HloInstruction::CreateDynamicSlice(
      row_shape, input, {index_zero, index_zero}, {1, 2}));

  HloInstruction* update =
      b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          matrix_shape, input, slice, {index_zero, index_zero}));

  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());
  HloInstruction* fused = entry->CreateFusionInstruction(
      {update, slice, index_zero}, HloInstruction::FusionKind::kLoop);
  std::unique_ptr<HloDataflowAnalysis> dataflow = RunAnalysis(*hlo_module);

  EXPECT_TRUE(dataflow->CanShareOperandBufferWithUser(input, {}, fused, {}));
}
2319 
// Parsed-HLO version of the "dynamic-update-slice fed by a dynamic-slice at
// the same indices" scenario, with the indices supplied as fusion parameters.
// Parameter 0 is expected to be able to share its buffer with the fusion.
TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithSameIndices) {
  const char* const kHloText = R"(
    HloModule test

    fused_computation {
      p0 = f32[10,20,30] parameter(0)
      p1 = s32[] parameter(1)
      p2 = s32[] parameter(2)
      p3 = s32[] parameter(3)
      slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30}
      ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p2, p3)
    }

    ENTRY test {
      p0 = f32[10,20,30] parameter(0)
      p1 = s32[] parameter(1)
      p2 = s32[] parameter(2)
      p3 = s32[] parameter(3)
      ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  HloComputation* entry = hlo_module->entry_computation();
  HloInstruction* fused = entry->root_instruction();
  HloInstruction* data_param = entry->parameter_instruction(0);

  std::unique_ptr<HloDataflowAnalysis> dataflow = RunAnalysis(*hlo_module);
  EXPECT_TRUE(
      dataflow->CanShareOperandBufferWithUser(data_param, {}, fused, {}));
}
2349 
// An elementwise compare whose result shape (pred[8]) differs from its operand
// shape (f32[8]): neither operand is expected to be shareable with the result.
TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) {
  HloComputation::Builder b(TestName());

  const Shape operand_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape result_shape = ShapeUtil::MakeShape(PRED, {8});
  HloInstruction* lhs = b.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "param0"));
  HloInstruction* rhs = b.AddInstruction(
      HloInstruction::CreateParameter(1, operand_shape, "param1"));
  HloInstruction* equal = b.AddInstruction(HloInstruction::CreateCompare(
      result_shape, lhs, rhs, ComparisonDirection::kEq));

  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(b.Build());
  std::unique_ptr<HloDataflowAnalysis> dataflow = RunAnalysis(*hlo_module);

  EXPECT_FALSE(dataflow->CanShareOperandBufferWithUser(lhs, {}, equal, {}));
  EXPECT_FALSE(dataflow->CanShareOperandBufferWithUser(rhs, {}, equal, {}));
}
2371 
// param -> exp -> copy, all f32[8]: both the exp and the copy are expected to
// be able to share their operand's buffer.
TEST_F(CanShareOperandBufferWithUserTest, CopyShares) {
  HloComputation::Builder b(TestName());

  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "param"));
  HloInstruction* exp_op = b.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, input));
  HloInstruction* copy_op = b.AddInstruction(
      HloInstruction::CreateUnary(vec_shape, HloOpcode::kCopy, exp_op));

  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(b.Build());
  std::unique_ptr<HloDataflowAnalysis> dataflow = RunAnalysis(*hlo_module);

  EXPECT_TRUE(dataflow->CanShareOperandBufferWithUser(input, {}, exp_op, {}));
  EXPECT_TRUE(
      dataflow->CanShareOperandBufferWithUser(exp_op, {}, copy_op, {}));
}
2392 
// Fuses a dynamic-update-slice of tuple element 1 (with its get-tuple-element,
// start index and update) into a loop fusion and checks which element of the
// tuple parameter may share its buffer with the fusion output.
TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) {
  HloComputation::Builder b(TestName());

  const Shape vec_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape pair_shape = ShapeUtil::MakeTupleShape({vec_shape, vec_shape});
  HloInstruction* tuple_param = b.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* elem0 = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, tuple_param, 0));
  HloInstruction* elem1 = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(vec_shape, tuple_param, 1));

  // DynamicUpdateSlice that overwrites part of tuple element 1.
  HloInstruction* start_index = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(2)));
  HloInstruction* new_values = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          vec_shape, elem1, new_values,
          std::initializer_list<HloInstruction*>({start_index})));
  b.AddInstruction(HloInstruction::CreateTuple({elem0, dus}));

  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());
  HloInstruction* fused = entry->CreateFusionInstruction(
      {dus, start_index, new_values, elem1},
      HloInstruction::FusionKind::kLoop);
  std::unique_ptr<HloDataflowAnalysis> dataflow = RunAnalysis(*hlo_module);

  // Only tuple element 1 may share its buffer with the fusion; element 0 may
  // not.
  EXPECT_FALSE(
      dataflow->CanShareOperandBufferWithUser(tuple_param, {0}, fused, {}));
  EXPECT_TRUE(
      dataflow->CanShareOperandBufferWithUser(tuple_param, {1}, fused, {}));
}
2429 
// Like FusedDynamicUpdateSlice, but the dynamic-update-slice is sandwiched
// between a convert to BF16 and a convert back to F32, all inside the fusion.
// The get-tuple-element feeding the fusion is still expected to be able to
// share its buffer with the fusion output.
TEST_F(CanShareOperandBufferWithUserTest,
       FusedDynamicUpdateSliceWithConvertCanShare) {
  HloComputation::Builder b(TestName());

  const Shape f32_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape bf16_shape = ShapeUtil::MakeShape(BF16, {8});
  const Shape pair_shape = ShapeUtil::MakeTupleShape({f32_shape, f32_shape});
  HloInstruction* tuple_param = b.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "tuple"));
  HloInstruction* elem0 = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, tuple_param, 0));
  HloInstruction* elem1 = b.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, tuple_param, 1));

  HloInstruction* to_bf16 =
      b.AddInstruction(HloInstruction::CreateConvert(bf16_shape, elem1));

  // DynamicUpdateSlice applied to the converted tuple element 1.
  HloInstruction* start_index = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32_t>(2)));
  HloInstruction* new_values = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
  HloInstruction* dus =
      b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          bf16_shape, to_bf16, new_values,
          std::initializer_list<HloInstruction*>({start_index})));

  HloInstruction* to_f32 =
      b.AddInstruction(HloInstruction::CreateConvert(f32_shape, dus));
  b.AddInstruction(HloInstruction::CreateTuple({elem0, to_f32}));

  auto hlo_module = CreateNewVerifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());
  HloInstruction* fused = entry->CreateFusionInstruction(
      {to_f32, dus, start_index, new_values, to_bf16},
      HloInstruction::FusionKind::kLoop);
  std::unique_ptr<HloDataflowAnalysis> dataflow = RunAnalysis(*hlo_module);

  EXPECT_TRUE(dataflow->CanShareOperandBufferWithUser(elem1, {}, fused, {}));
}
2470 
// Unfused dynamic-update-slice whose three operands are all parameters: only
// the operand being updated is expected to be shareable with the result.
TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) {
  HloComputation::Builder b(TestName());

  const Shape target_shape = ShapeUtil::MakeShape(F32, {1, 8});
  const Shape patch_shape = ShapeUtil::MakeShape(F32, {1, 4});
  const Shape index_shape = ShapeUtil::MakeShape(S32, {2});
  HloInstruction* target = b.AddInstruction(
      HloInstruction::CreateParameter(0, target_shape, "data"));
  HloInstruction* patch = b.AddInstruction(
      HloInstruction::CreateParameter(1, patch_shape, "update"));
  HloInstruction* start_index = b.AddInstruction(
      HloInstruction::CreateParameter(2, index_shape, "start"));

  HloInstruction* dus =
      b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
          target_shape, target, patch, {start_index}));

  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEntryComputation(b.Build());
  std::unique_ptr<HloDataflowAnalysis> dataflow = RunAnalysis(*hlo_module);

  // Sharing is allowed with the operand being updated, and with neither the
  // update nor the start index.
  EXPECT_TRUE(dataflow->CanShareOperandBufferWithUser(target, {}, dus, {}));
  EXPECT_FALSE(dataflow->CanShareOperandBufferWithUser(patch, {}, dus, {}));
  EXPECT_FALSE(
      dataflow->CanShareOperandBufferWithUser(start_index, {}, dus, {}));
}
2500 
// Single-output scatter parsed from HLO text: only the scatter's first
// parameter (the operand being scattered into) is expected to be shareable
// with the result; the indices and updates are not.
TEST_F(CanShareOperandBufferWithUserTest, ScatterCanShare) {
  const char* const kHloText = R"(
    HloModule TensorFlowScatterV1

    update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
      lhs = s32[] parameter(0)
      ROOT rhs = s32[] parameter(1)
    }

    ENTRY main {
      operand = s32[3,3] parameter(0)
      indices = s32[2] parameter(1)
      updates = s32[2,3] parameter(2)
      ROOT scatter = s32[3,3] scatter(operand, indices, updates),
          to_apply=update_s32,
          update_window_dims={1},
          inserted_window_dims={0},
          scatter_dims_to_operand_dims={0},
          index_vector_dim=1
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));
  HloComputation* entry = hlo_module->entry_computation();
  std::unique_ptr<HloDataflowAnalysis> dataflow = RunAnalysis(*hlo_module);

  HloInstruction* operand = entry->parameter_instruction(0);
  HloInstruction* indices = entry->parameter_instruction(1);
  HloInstruction* updates = entry->parameter_instruction(2);
  HloInstruction* scatter_op = entry->root_instruction();

  EXPECT_TRUE(
      dataflow->CanShareOperandBufferWithUser(operand, {}, scatter_op, {}));
  EXPECT_FALSE(
      dataflow->CanShareOperandBufferWithUser(indices, {}, scatter_op, {}));
  EXPECT_FALSE(
      dataflow->CanShareOperandBufferWithUser(updates, {}, scatter_op, {}));
}
2538 
// Checks buffer sharing for a scatter with two operands and a tuple result:
// operand i is expected to be shareable only with output tuple index {i};
// the indices and both updates operands are expected to be shareable with
// neither output.
TEST_F(CanShareOperandBufferWithUserTest, MultioutputScatterCanShare) {
  const char* hlo_text = R"(
    HloModule MultioutputScatter

    update {
      lhs0 = s32[] parameter(0)
      lhs1 = f32[] parameter(1)
      rhs0 = s32[] parameter(2)
      rhs1 = f32[] parameter(3)
      ROOT tuple = tuple(rhs0, rhs1)
    }

    ENTRY main {
      operand0 = s32[3,3] parameter(0)
      operand1 = f32[3,3] parameter(1)
      indices = s32[2] parameter(2)
      updates0 = s32[2,3] parameter(3)
      updates1 = f32[2,3] parameter(4)
      ROOT scatter = (s32[3,3], f32[3,3])
      scatter(operand0, operand1, indices, updates0, updates1),
          to_apply=update,
          update_window_dims={1},
          inserted_window_dims={0},
          scatter_dims_to_operand_dims={0},
          index_vector_dim=1
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
  auto computation = module->entry_computation();
  auto dataflow_analysis = RunAnalysis(*module);

  HloInstruction* operand0_param = computation->parameter_instruction(0);
  HloInstruction* operand1_param = computation->parameter_instruction(1);
  HloInstruction* indices_param = computation->parameter_instruction(2);
  HloInstruction* updates0_param = computation->parameter_instruction(3);
  HloInstruction* updates1_param = computation->parameter_instruction(4);
  HloInstruction* scatter = computation->root_instruction();

  // Each scatter operand is checked against both output tuple elements.
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(
      operand0_param, {}, scatter, {0}));
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(
      operand0_param, {}, scatter, {1}));
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(
      operand1_param, {}, scatter, {0}));
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(
      operand1_param, {}, scatter, {1}));
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(
      indices_param, {}, scatter, {0}));
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(
      indices_param, {}, scatter, {1}));
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(
      updates0_param, {}, scatter, {0}));
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(
      updates0_param, {}, scatter, {1}));
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(
      updates1_param, {}, scatter, {0}));
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(
      updates1_param, {}, scatter, {1}));
}
2598 
// Checks buffer sharing for triangular-solve(a, b): the result is expected to
// be able to share a buffer with the second operand `b` but not with the
// first operand `a`.
TEST_F(CanShareOperandBufferWithUserTest, TriangularSolveCanShare) {
  const char* hlo_text = R"(
    HloModule TensorFlowTriangularSolve

    ENTRY main {
      a = f32[4,4]{1,0} parameter(0)
      b = f32[3,4]{1,0} parameter(1)
      ROOT triangular-solve = f32[3,4]{1,0} triangular-solve(a, b), lower=true,
                                              transpose_a=NO_TRANSPOSE
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text));
  auto computation = module->entry_computation();
  auto dataflow_analysis = RunAnalysis(*module);

  HloInstruction* lhs_param = computation->parameter_instruction(0);
  HloInstruction* rhs_param = computation->parameter_instruction(1);
  HloInstruction* triangular_solve = computation->root_instruction();

  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(
      lhs_param, {}, triangular_solve, {}));
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(
      rhs_param, {}, triangular_solve, {}));
}
2623 
// Checks that a single-operand sort (built with MakeSortHlo) is expected to be
// able to share its output buffer with its `keys` operand.
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
  auto builder = HloComputation::Builder(TestName());
  auto module = CreateNewVerifiedModule();

  Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
  auto keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  TF_ASSERT_OK_AND_ASSIGN(
      auto* sort, MakeSortHlo(keys_shape, {keys}, -1, /*is_stable=*/false,
                              &builder, module.get()));

  module->AddEntryComputation(builder.Build());
  auto dataflow_analysis = RunAnalysis(*module);

  EXPECT_TRUE(
      dataflow_analysis->CanShareOperandBufferWithUser(keys, {}, sort, {}));
}
2641 
// Checks buffer sharing for a two-operand (keys, values) sort with a tuple
// result: each operand is expected to be shareable with the tuple element at
// its own position and not with the other element.
TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
  auto builder = HloComputation::Builder(TestName());
  auto module = CreateNewVerifiedModule();

  Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
  Shape values_shape = ShapeUtil::MakeShape(F32, {8});
  auto keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  auto values = builder.AddInstruction(
      HloInstruction::CreateParameter(1, values_shape, "values"));
  TF_ASSERT_OK_AND_ASSIGN(
      auto* sort,
      MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}),
                  {keys, values}, 0, /*is_stable=*/false, &builder,
                  module.get()));

  module->AddEntryComputation(builder.Build());
  auto dataflow_analysis = RunAnalysis(*module);

  // The buffer for the keys can be shared with the first tuple entry.
  EXPECT_TRUE(
      dataflow_analysis->CanShareOperandBufferWithUser(keys, {}, sort, {0}));
  // The buffer for the values can be shared with the second tuple entry.
  EXPECT_TRUE(
      dataflow_analysis->CanShareOperandBufferWithUser(values, {}, sort, {1}));
  // Verify that the buffers are not shared with the "wrong" tuple entry.
  EXPECT_FALSE(
      dataflow_analysis->CanShareOperandBufferWithUser(keys, {}, sort, {1}));
  EXPECT_FALSE(
      dataflow_analysis->CanShareOperandBufferWithUser(values, {}, sort, {0}));
}
2673 
// Builds add(dot(a, b), broadcast(1.0)), fuses {add, dot} into a kOutput
// fusion, and checks that the fusion is expected to be able to share its
// buffer with the broadcast operand of the add.
TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto a = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
  auto b = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  // Contract dimension 1 of the lhs with dimension 0 of the rhs.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      2, PrecisionConfig::DEFAULT);
  auto dot = builder.AddInstruction(
      HloInstruction::CreateDot(data_shape, a, b, dot_dnums, precision_config));

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto add_operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kAdd, dot, add_operand));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  auto fusion = computation->CreateFusionInstruction(
      {add, dot}, HloInstruction::FusionKind::kOutput);
  auto dataflow_analysis = RunAnalysis(*module);

  // Output fused dot add should be able to share buffer with 'add_operand'.
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(add_operand, {},
                                                               fusion, {}));
}
2710 
// Builds add(reverse(operand), two), fuses {add, two, reverse} into a kOutput
// fusion, and checks that the fusion is expected NOT to be able to share its
// buffer with `operand`.
TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) {
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));

  auto reverse = builder.AddInstruction(
      HloInstruction::CreateReverse(data_shape, operand, {0, 1}));

  auto two = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));

  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  auto fusion = computation->CreateFusionInstruction(
      {add, two, reverse}, HloInstruction::FusionKind::kOutput);
  auto dataflow_analysis = RunAnalysis(*module);

  // Output fused operand->reverse->add cannot alias operand buffer 'operand'.
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(operand, {},
                                                                fusion, {}));
}
2739 
// Runs the analysis with a caller-supplied `can_share_buffer` callback that
// answers `fusion->IsLoopFusion()`. The fusion under test is created with
// FusionKind::kInput, and the test expects sharing with `operand` to be
// rejected.
TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
  auto builder = HloComputation::Builder(TestName());
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});

  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));
  auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kMultiply, operand, operand));
  auto two = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, mul, two));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());
  auto fusion = computation->CreateFusionInstruction(
      {add, two, mul}, HloInstruction::FusionKind::kInput);
  auto dataflow_analysis = RunAnalysis(
      *module,
      /*can_share_buffer=*/[](const HloInstruction* fusion,
                              const HloInstruction*, const ShapeIndex&) {
        return fusion->IsLoopFusion();
      });

  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(operand, {},
                                                                fusion, {}));
}
2769 
// Builds a while loop over an f32[8] value (condition: reduce-and of
// data == data; body: data + data) and checks that the while instruction is
// expected to be able to share its buffer with its `data` operand.
TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) {
  auto module = CreateNewVerifiedModule();
  Shape data_shape = ShapeUtil::MakeShape(F32, {8});
  Shape pred_scalar_shape = ShapeUtil::MakeShape(PRED, {});

  // Scalar logical-and computation used as the reducer in the loop condition.
  auto b = HloComputation::Builder(TestName() + ".And");
  auto p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, pred_scalar_shape, "p0"));
  auto p1 = b.AddInstruction(
      HloInstruction::CreateParameter(1, pred_scalar_shape, "p1"));
  b.AddInstruction(
      HloInstruction::CreateBinary(pred_scalar_shape, HloOpcode::kAnd, p0, p1));
  auto and_computation = module->AddEmbeddedComputation(b.Build());

  // Loop condition: compare data with itself elementwise, then and-reduce the
  // result to a scalar predicate.
  auto make_cond = [&data_shape, &and_computation]() {
    auto builder = HloComputation::Builder(TestName() + ".Cond");
    auto data = builder.AddInstruction(
        HloInstruction::CreateParameter(0, data_shape, "data"));
    auto compare = builder.AddInstruction(HloInstruction::CreateCompare(
        ShapeUtil::MakeShape(PRED, {8}), data, data, ComparisonDirection::kEq));
    auto true_value = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
    builder.AddInstruction(
        HloInstruction::CreateReduce(ShapeUtil::MakeShape(PRED, {}), compare,
                                     true_value, {0}, and_computation));
    return builder.Build();
  };

  // Loop body: data + data.
  auto make_body = [&data_shape]() {
    auto builder = HloComputation::Builder(TestName() + ".Body");
    auto data = builder.AddInstruction(
        HloInstruction::CreateParameter(0, data_shape, "data"));
    builder.AddInstruction(
        HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, data, data));
    return builder.Build();
  };

  HloComputation* cond_computation =
      module->AddEmbeddedComputation(make_cond());
  HloComputation* body_computation =
      module->AddEmbeddedComputation(make_body());

  auto builder = HloComputation::Builder(TestName());
  auto data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data"));
  auto whil = builder.AddInstruction(HloInstruction::CreateWhile(
      data_shape, cond_computation, body_computation, data));
  module->AddEntryComputation(builder.Build());

  auto dataflow_analysis = RunAnalysis(*module);

  // The While instruction can share with the data operand.
  EXPECT_TRUE(
      dataflow_analysis->CanShareOperandBufferWithUser(data, {}, whil, {}));
}
2825 
2826 // Tests that Call can alias operand buffer if the only use of the operand
2827 // in the called computation is an elementwise instruction.
TEST_F(CanShareOperandBufferWithUserTest,CallToComputationWithFusionRoot)2828 TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
2829   Shape shape = ShapeUtil::MakeShape(F32, {8});
2830   // Build sub-computation with fusion root.
2831   auto sub_builder = HloComputation::Builder(TestName() + "_sub");
2832   auto sub_param = sub_builder.AddInstruction(
2833       HloInstruction::CreateParameter(0, shape, "sub_param"));
2834   auto one = sub_builder.AddInstruction(
2835       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
2836   auto ones = sub_builder.AddInstruction(
2837       HloInstruction::CreateBroadcast(shape, one, {}));
2838   auto add = sub_builder.AddInstruction(
2839       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, sub_param, ones));
2840 
2841   auto module = CreateNewVerifiedModule();
2842   auto sub_computation = module->AddEmbeddedComputation(sub_builder.Build());
2843   sub_computation->CreateFusionInstruction({add, ones},
2844                                            HloInstruction::FusionKind::kLoop);
2845 
2846   // Build entry-computation with kCall which calls 'sub_computation'.
2847   auto builder = HloComputation::Builder(TestName());
2848 
2849   auto param = builder.AddInstruction(
2850       HloInstruction::CreateParameter(0, shape, "param"));
2851   auto reverse =
2852       builder.AddInstruction(HloInstruction::CreateReverse(shape, param, {0}));
2853   auto call = builder.AddInstruction(
2854       HloInstruction::CreateCall(shape, {reverse}, sub_computation));
2855   module->AddEntryComputation(builder.Build());
2856 
2857   auto dataflow_analysis = RunAnalysis(*module);
2858 
2859   EXPECT_TRUE(
2860       dataflow_analysis->CanShareOperandBufferWithUser(reverse, {}, call, {}));
2861 }
2862 
// Checks a multi-output fusion that concatenates two reshaped elementwise
// results and slices them back apart: p0/p1 (which feed slice0) are expected
// to be shareable with output {0}, p2/p3 (which feed slice1) with output {1},
// and p0 is expected not to be shareable with the differently sized
// output {1}.
TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceWithElementwise) {
  const char* kModule = R"(
    HloModule test

    fused_computation {
      p0 = f32[10,20] parameter(0)
      p1 = f32[10,20] parameter(1)
      p2 = f32[10,10] parameter(2)
      p3 = f32[10,10] parameter(3)
      add0 = f32[10, 20] add(p0, p1)
      sub0 = f32[10, 10] subtract(p2, p3)
      reshape0 = f32[200] reshape(add0)
      reshape1 = f32[100] reshape(sub0)
      concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
      slice0 = f32[200] slice(concat0), slice={[0:200]}
      slice1 = f32[100] slice(concat0), slice={[200:300]}
      ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
    }

    ENTRY test {
      p0 = f32[10,20] parameter(0)
      p1 = f32[10,20] parameter(1)
      p2 = f32[10,10] parameter(2)
      p3 = f32[10,10] parameter(3)
      ROOT fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
  auto* fusion = module->entry_computation()->root_instruction();
  auto* param0 = module->entry_computation()->parameter_instruction(0);
  auto* param1 = module->entry_computation()->parameter_instruction(1);
  auto* param2 = module->entry_computation()->parameter_instruction(2);
  auto* param3 = module->entry_computation()->parameter_instruction(3);

  auto dataflow_analysis = RunAnalysis(*module);
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param0, {},
                                                               fusion, {0}));
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param1, {},
                                                               fusion, {0}));
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param2, {},
                                                               fusion, {1}));
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param3, {},
                                                               fusion, {1}));
  // Tensors of different sizes cannot share buffer.
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(param0, {},
                                                                fusion, {1}));
}
2910 
// Negative case for the concat/slice fusion pattern: p0 feeds the concatenate
// both directly and through add0, so it is expected to be shareable with
// neither output; p1 is expected to be shareable only with output {1}.
TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceNegativeTest) {
  const char* kModule = R"(
    HloModule test

    fused_computation {
      // p0 has multiple transitive uses fed to concat. So, p0 cannot share
      // buffer with outputs because the aliased output could be written before
      // all the uses of p0 are finished.
      p0 = f32[100] parameter(0)
      p1 = f32[100] parameter(1)
      add0 = f32[100] add(p0, p1)
      concat0 = f32[200] concatenate(p0, add0), dimensions={0}
      slice0 = f32[100] slice(concat0), slice={[0:100]}
      slice1 = f32[100] slice(concat0), slice={[100:200]}
      ROOT tuple = (f32[100], f32[100]) tuple(slice0, slice1)
    }

    ENTRY test {
      p0 = f32[100] parameter(0)
      p1 = f32[100] parameter(1)
      ROOT fusion = (f32[100], f32[100]) fusion(p0, p1),
                        kind=kInput, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
  auto* fusion = module->entry_computation()->root_instruction();
  auto* param0 = module->entry_computation()->parameter_instruction(0);
  auto* param1 = module->entry_computation()->parameter_instruction(1);

  auto dataflow_analysis = RunAnalysis(*module);
  // p0 cannot share with either fusion{0} or fusion{1}.
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(param0, {},
                                                                fusion, {0}));
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(param0, {},
                                                                fusion, {1}));
  // p1 cannot share with fusion{0} because we're not sure about their
  // relationship.
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(param1, {},
                                                                fusion, {0}));
  // p1 can share with fusion{1} because they will be executed in an
  // elementwise manner.
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param1, {},
                                                               fusion, {1}));
}
2955 
// Checks a fusion containing two concatenate/slice groups that both consume
// p0 directly: p0 is expected to be shareable with none of the four outputs,
// while p1 is expected to be shareable with outputs {1} and {3} (the slices
// fed by add0 and sub0) but not with {0} or {2}.
TEST_F(CanShareOperandBufferWithUserTest, MultipleConcatenates) {
  const char* kModule = R"(
    HloModule test

    fused_computation {
      p0 = f32[100] parameter(0)
      p1 = f32[100] parameter(1)
      add0 = f32[100] add(p0, p1)
      sub0 = f32[100] subtract(p1, p1)
      concat0 = f32[200] concatenate(p0, add0), dimensions={0}
      slice0 = f32[100] slice(concat0), slice={[0:100]}
      slice1 = f32[100] slice(concat0), slice={[100:200]}
      concat1 = f32[200] concatenate(p0, sub0), dimensions={0}
      slice2 = f32[100] slice(concat1), slice={[0:100]}
      slice3 = f32[100] slice(concat1), slice={[100:200]}
      ROOT tuple = (f32[100], f32[100], f32[100], f32[100])
                       tuple(slice0, slice1, slice2, slice3)
    }

    ENTRY test {
      p0 = f32[100] parameter(0)
      p1 = f32[100] parameter(1)
      ROOT fusion = (f32[100], f32[100], f32[100], f32[100])
          fusion(p0, p1), kind=kInput, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
  auto* fusion = module->entry_computation()->root_instruction();
  auto* param0 = module->entry_computation()->parameter_instruction(0);
  auto* param1 = module->entry_computation()->parameter_instruction(1);

  auto dataflow_analysis = RunAnalysis(*module);
  // p0 cannot share.
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(param0, {},
                                                                fusion, {0}));
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(param0, {},
                                                                fusion, {1}));
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(param0, {},
                                                                fusion, {2}));
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(param0, {},
                                                                fusion, {3}));
  // p1 can share with either fusion{1} or fusion{3}.
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param1, {},
                                                               fusion, {1}));
  EXPECT_TRUE(dataflow_analysis->CanShareOperandBufferWithUser(param1, {},
                                                               fusion, {3}));
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(param1, {},
                                                                fusion, {0}));
  EXPECT_FALSE(dataflow_analysis->CanShareOperandBufferWithUser(param1, {},
                                                                fusion, {2}));
}
3007 
3008 using GetInPlaceInputOutputPairsTest = HloTestBase;
3009 
// Checks that a bare dynamic-update-slice reports exactly one in-place pair:
// operand 0 (at shape index {}) paired with the output at shape index {}.
TEST_F(GetInPlaceInputOutputPairsTest, DUS) {
  const char* kModule = R"(
    HloModule test

    ENTRY test {
      p0 = f32[10] parameter(0)
      p1 = f32[5] parameter(1)
      p2 = s32[] parameter(2)
      ROOT dus = f32[10] dynamic-update-slice(p0, p1, p2)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
  HloInstruction* dus = module->entry_computation()->root_instruction();

  auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(dus);
  std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
  expected_pairs.push_back({HloOperandIndex{0, {}}, {}});
  EXPECT_EQ(in_place_pairs, expected_pairs);
}
3029 
// Checks that a kLoop fusion whose root is a dynamic-update-slice on fused
// parameter 0 reports the same single in-place pair as the bare instruction:
// operand 0 (at shape index {}) paired with output shape index {}.
TEST_F(GetInPlaceInputOutputPairsTest, DUSFusion) {
  const char* kModule = R"(
    HloModule test

    fused_computation {
      p0 = f32[10] parameter(0)
      p1 = f32[5] parameter(1)
      p2 = s32[] parameter(2)
      ROOT dus = f32[10] dynamic-update-slice(p0, p1, p2)
    }

    ENTRY test {
      p0 = f32[10] parameter(0)
      p1 = f32[5] parameter(1)
      p2 = s32[] parameter(2)
      ROOT fusion = f32[10] fusion(p0, p1, p2), kind=kLoop, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
  HloInstruction* fusion = module->entry_computation()->root_instruction();

  auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion);
  std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
  expected_pairs.push_back({HloOperandIndex{0, {}}, {}});
  EXPECT_EQ(in_place_pairs, expected_pairs);
}
3056 
// Checks that a kLoop fusion containing only an elementwise add (no
// dynamic-update-slice) reports no in-place input/output pairs.
TEST_F(GetInPlaceInputOutputPairsTest, NonDUSFusion) {
  const char* kModule = R"(
    HloModule test

    fused_computation {
      p0 = f32[10] parameter(0)
      p1 = f32[10] parameter(1)
      ROOT add = f32[10] add(p0, p1)
    }

    ENTRY test {
      p0 = f32[10] parameter(0)
      p1 = f32[10] parameter(1)
      ROOT fusion = f32[10] fusion(p0, p1), kind=kLoop, calls=fused_computation
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
  HloInstruction* fusion = module->entry_computation()->root_instruction();

  auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion);
  EXPECT_THAT(in_place_pairs, IsEmpty());
}
3079 
// Checks that the in-place pair is still reported when the
// dynamic-update-slice sits two fusion levels deep: the outer fusion is
// expected to report operand 0 (at shape index {}) paired with output shape
// index {}.
TEST_F(GetInPlaceInputOutputPairsTest, NestedDUSFusion) {
  const char* kModule = R"(
    HloModule test

    fused_computation1 {
      p0 = f32[10] parameter(0)
      p1 = f32[5] parameter(1)
      p2 = s32[] parameter(2)
      ROOT dus = f32[10] dynamic-update-slice(p0, p1, p2)
    }

    fused_computation2 {
      p0 = f32[10] parameter(0)
      p1 = f32[5] parameter(1)
      p2 = s32[] parameter(2)
      ROOT fusion = f32[10] fusion(p0, p1, p2), kind=kLoop, calls=fused_computation1
    }

    ENTRY test {
      p0 = f32[10] parameter(0)
      p1 = f32[5] parameter(1)
      p2 = s32[] parameter(2)
      ROOT fusion = f32[10] fusion(p0, p1, p2), kind=kLoop, calls=fused_computation2
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
  HloInstruction* fusion = module->entry_computation()->root_instruction();

  auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion);
  std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
  expected_pairs.push_back({HloOperandIndex{0, {}}, {}});
  EXPECT_EQ(in_place_pairs, expected_pairs);
}
3113 
// Checks in-place pairs through nested multi-output fusions with tuple-shaped
// operands. The inner fusion updates element {1} of its tuple operand 1 and
// returns it as output {1}; the outer fusion feeds element {0} of its own
// operand 1 into that position and returns the result as output {2}.
TEST_F(GetInPlaceInputOutputPairsTest, NestedMultiOutputDUSFusion) {
  const char* kModule = R"(
    HloModule test

    fused_computation1 {
      p0 = s32[] parameter(0)
      p1 = (f32[5],f32[10]) parameter(1)
      gte0 = f32[5] get-tuple-element(p1), index=0
      gte1 = f32[10] get-tuple-element(p1), index=1
      dus = f32[10] dynamic-update-slice(gte1, gte0, p0)
      negate = f32[5] negate(gte0)
      ROOT tuple = (f32[5],f32[10]) tuple(negate, dus)
    }

    fused_computation2 {
      p0 = f32[5] parameter(0)
      p1 = (f32[10],s32[]) parameter(1)
      gte0 = f32[10] get-tuple-element(p1), index=0
      gte1 = s32[] get-tuple-element(p1), index=1
      in_tuple = (f32[5],f32[10]) tuple(p0, gte0)
      inner_fusion = (f32[5],f32[10]) fusion(gte1, in_tuple), kind=kLoop, calls=fused_computation1
      fusion_gte0 = f32[5] get-tuple-element(inner_fusion), index=0
      fusion_gte1 = f32[10] get-tuple-element(inner_fusion), index=1
      negate = f32[5] negate(p0)
      ROOT tuple = (f32[5],f32[5],f32[10]) tuple(negate, fusion_gte0, fusion_gte1)
    }

    ENTRY test {
      p0 = f32[5] parameter(0)
      p1 = (f32[10],s32[]) parameter(1)
      ROOT fusion = (f32[5],f32[5],f32[10]) fusion(p0, p1), kind=kLoop, calls=fused_computation2
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kModule));
  HloInstruction* fusion = module->entry_computation()->root_instruction();
  HloInstruction* inner_fusion = FindInstruction(module.get(), "inner_fusion");

  // Inner fusion: operand 1 at tuple index {1} pairs with output index {1}.
  auto inner_in_place_pairs =
      HloDataflowAnalysis::GetInPlaceInputOutputPairs(inner_fusion);
  std::vector<std::pair<HloOperandIndex, ShapeIndex>> inner_expected_pairs;
  inner_expected_pairs.push_back({HloOperandIndex{1, {1}}, {1}});
  EXPECT_EQ(inner_in_place_pairs, inner_expected_pairs);

  // Outer fusion: operand 1 at tuple index {0} pairs with output index {2}.
  auto in_place_pairs = HloDataflowAnalysis::GetInPlaceInputOutputPairs(fusion);
  std::vector<std::pair<HloOperandIndex, ShapeIndex>> expected_pairs;
  expected_pairs.push_back({HloOperandIndex{1, {0}}, {2}});
  EXPECT_EQ(in_place_pairs, expected_pairs);
}
3162 
3163 }  // namespace
3164 }  // namespace xla
3165