xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_instruction_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
17 
18 #include <optional>
19 #include <set>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_map.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/protobuf_util.h"
27 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
28 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
29 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
30 #include "tensorflow/compiler/xla/service/hlo_computation.h"
31 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/test_helpers.h"
35 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/compiler/xla/window_util.h"
38 #include "tensorflow/core/lib/core/status_test_util.h"
39 
40 namespace xla {
41 namespace {
42 
43 using ::testing::ElementsAre;
44 using ::testing::UnorderedElementsAre;
45 
46 class HloInstructionTest : public HloTestBase {
47  protected:
48   Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
49 };
50 
51 // Simple visitor that collects the number of users and operands for certain HLO
52 // nodes. It also verifies some of the DFS visiting invariants (operands visited
53 // before their users, nodes not visited twice, etc.)
54 class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
55  public:
DefaultAction(HloInstruction * hlo_instruction)56   Status DefaultAction(HloInstruction* hlo_instruction) override {
57     return Unimplemented("not implemented %s",
58                          HloOpcodeString(hlo_instruction->opcode()));
59   }
60 
HandleParameter(HloInstruction * parameter)61   Status HandleParameter(HloInstruction* parameter) override {
62     EXPECT_FALSE(count_.contains(parameter));
63     count_[parameter] = GetCountsForNode(parameter);
64     return OkStatus();
65   }
66 
HandleConstant(HloInstruction * constant)67   Status HandleConstant(HloInstruction* constant) override {
68     EXPECT_FALSE(count_.contains(constant));
69     count_[constant] = GetCountsForNode(constant);
70     return OkStatus();
71   }
72 
HandleAdd(HloInstruction * add)73   Status HandleAdd(HloInstruction* add) override {
74     auto lhs = add->operand(0);
75     auto rhs = add->operand(1);
76     EXPECT_FALSE(count_.contains(add));
77     EXPECT_TRUE(count_.contains(lhs));
78     EXPECT_TRUE(count_.contains(rhs));
79     count_[add] = GetCountsForNode(add);
80     return OkStatus();
81   }
82 
HandleNegate(HloInstruction * negate)83   Status HandleNegate(HloInstruction* negate) override {
84     auto operand = negate->operand(0);
85     EXPECT_FALSE(count_.contains(negate));
86     EXPECT_TRUE(count_.contains(operand));
87     count_[negate] = GetCountsForNode(negate);
88     return OkStatus();
89   }
90 
HandleMap(HloInstruction * map)91   Status HandleMap(HloInstruction* map) override {
92     EXPECT_FALSE(count_.contains(map));
93     for (HloInstruction* arg : map->operands()) {
94       EXPECT_TRUE(count_.contains(arg));
95     }
96     count_[map] = GetCountsForNode(map);
97     return OkStatus();
98   }
99 
HandleReduce(HloInstruction * reduce)100   Status HandleReduce(HloInstruction* reduce) override {
101     auto arg = reduce->operand(0);
102     auto init_value = reduce->operand(1);
103     EXPECT_FALSE(count_.contains(reduce));
104     EXPECT_TRUE(count_.contains(arg));
105     EXPECT_TRUE(count_.contains(init_value));
106     count_[reduce] = GetCountsForNode(reduce);
107     return OkStatus();
108   }
109 
NumOperands(const HloInstruction * node)110   int64_t NumOperands(const HloInstruction* node) {
111     auto count_iterator = count_.find(node);
112     EXPECT_NE(count_.end(), count_iterator);
113     return count_iterator->second.operand_count;
114   }
115 
NumUsers(const HloInstruction * node)116   int64_t NumUsers(const HloInstruction* node) {
117     auto count_iterator = count_.find(node);
118     EXPECT_NE(count_.end(), count_iterator);
119     return count_iterator->second.user_count;
120   }
121 
122  private:
123   struct NumOpsAndUsers {
124     int64_t operand_count;
125     int64_t user_count;
126   };
127 
128   // Helper function to count operands and users for the given HLO.
GetCountsForNode(const HloInstruction * node)129   NumOpsAndUsers GetCountsForNode(const HloInstruction* node) {
130     NumOpsAndUsers counts{node->operand_count(), node->user_count()};
131     return counts;
132   }
133 
134   // Counters for HLOs. Maps HLO to a NumOpsAndUsers.
135   absl::flat_hash_map<const HloInstruction*, NumOpsAndUsers> count_;
136 };
137 
// Checks the opcode, shape and operand count of a freshly created parameter
// instruction that is not attached to any computation.
TEST_F(HloInstructionTest, BasicProperties) {
  auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo");

  EXPECT_EQ(HloOpcode::kParameter, parameter->opcode());
  EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32));
  EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32));
  // Compare the count explicitly rather than relying on an integer-to-bool
  // conversion, so a failure prints the actual operand count.
  EXPECT_EQ(0, parameter->operand_count());
}
146 
// A single add fed by two distinct scalar parameters:
//
//   foo --+
//         +--> add
//   bar --+
TEST_F(HloInstructionTest, UserWithTwoOperands) {
  HloComputation::Builder builder(TestName());
  HloInstruction* foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Operand and user edges as seen directly on the instructions.
  EXPECT_THAT(add->operands(), UnorderedElementsAre(foo, bar));
  EXPECT_THAT(foo->users(), UnorderedElementsAre(add));
  EXPECT_THAT(bar->users(), UnorderedElementsAre(add));

  // The same counts as seen by a DFS visitor rooted at the add.
  OpAndUserCollectingVisitor visitor;
  ASSERT_IS_OK(add->Accept(&visitor));

  EXPECT_EQ(2, visitor.NumOperands(add));
  EXPECT_EQ(0, visitor.NumUsers(add));
  EXPECT_EQ(1, visitor.NumUsers(foo));
  EXPECT_EQ(1, visitor.NumUsers(bar));
}
173 
// One parameter with three users, one of which also uses a second parameter:
//
//   foo --> exp1
//   foo --> exp2
//   foo --+
//         +--> add
//   bar --+
TEST_F(HloInstructionTest, MultipleUsers) {
  HloComputation::Builder builder(TestName());
  HloInstruction* foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
  HloInstruction* exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // User counts read straight from the instructions.
  EXPECT_EQ(3, foo->user_count());
  EXPECT_EQ(1, bar->user_count());
  EXPECT_EQ(0, exp1->user_count());
  EXPECT_EQ(0, exp2->user_count());
  EXPECT_EQ(0, add->user_count());

  // Counts collected by a visitor that is started at the add.
  OpAndUserCollectingVisitor visitor;
  ASSERT_IS_OK(add->Accept(&visitor));

  EXPECT_EQ(2, visitor.NumOperands(add));
  EXPECT_EQ(3, visitor.NumUsers(foo));
}
210 
// An add node that uses the same HLO for both of its operands must be counted
// as a single user of that HLO:
//
//   foo ==(lhs and rhs)==> add
TEST_F(HloInstructionTest, RepeatedUser) {
  HloComputation::Builder builder(TestName());
  HloInstruction* foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // One distinct user ...
  EXPECT_EQ(1, foo->user_count());

  // ... but the add still has two operand slots, both holding the same HLO.
  EXPECT_EQ(2, add->operand_count());
}
236 
// A diamond in which a constant is shared by two adds whose results are then
// added together:
//
//   addleft  = add(param0, c0)
//   addright = add(c0, param1)
//   addtotal = add(addleft, addright)
TEST_F(HloInstructionTest, MultipleUsersAndOperands) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "param1"));
  HloInstruction* c0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* addleft = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0, c0));
  HloInstruction* addright = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c0, param1));
  HloInstruction* addtotal = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Walk the whole graph from its root and inspect the collected counts.
  OpAndUserCollectingVisitor visitor;
  ASSERT_IS_OK(addtotal->Accept(&visitor));

  EXPECT_EQ(2, visitor.NumUsers(c0));
  EXPECT_EQ(2, visitor.NumOperands(addleft));
  EXPECT_EQ(2, visitor.NumOperands(addright));
  EXPECT_EQ(2, visitor.NumOperands(addtotal));
}
274 
// Same diamond as MultipleUsersAndOperands, but the shared value is a negated
// constant and the final sum is negated again:
//
//   neg1     = negate(c0)
//   addleft  = add(param0, neg1)
//   addright = add(neg1, param1)
//   addtotal = add(addleft, addright)
//   neg2     = negate(addtotal)
TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "param1"));
  HloInstruction* c0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* neg1 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, c0));
  HloInstruction* addleft = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0, neg1));
  HloInstruction* addright = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, neg1, param1));
  HloInstruction* addtotal = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, addleft, addright));
  HloInstruction* neg2 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Walk the whole graph from its root and inspect the collected counts.
  OpAndUserCollectingVisitor visitor;
  ASSERT_IS_OK(neg2->Accept(&visitor));

  EXPECT_EQ(1, visitor.NumUsers(c0));
  EXPECT_EQ(2, visitor.NumUsers(neg1));
  EXPECT_EQ(2, visitor.NumOperands(addleft));
  EXPECT_EQ(2, visitor.NumOperands(addright));
  EXPECT_EQ(2, visitor.NumOperands(addtotal));
  EXPECT_EQ(1, visitor.NumOperands(neg2));
  EXPECT_EQ(0, visitor.NumUsers(neg2));
}
326 
// A map applying an x+1 computation is the only operation in the entry
// computation:
//
//   param0[100x10] ---> (map x+1)
TEST_F(HloInstructionTest, TrivialMap) {
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10});
  auto module = CreateNewVerifiedModule();

  // Embedded scalar computation x + 1.0 that the map applies.
  HloComputation::Builder embedded_builder("f32+1");
  HloInstruction* param = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "x"));
  HloInstruction* value = embedded_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  embedded_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value));
  HloComputation* add_f32 =
      module->AddEmbeddedComputation(embedded_builder.Build());

  // Entry computation: one array parameter fed into the map.
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10, "p"));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(f32a100x10, {param0}, add_f32));
  module->AddEntryComputation(builder.Build());

  OpAndUserCollectingVisitor visitor;
  ASSERT_IS_OK(map->Accept(&visitor));

  // Only the entry computation is inspected; the mapper computation is not
  // walked yet.
  EXPECT_EQ(1, visitor.NumUsers(param0));
  EXPECT_EQ(0, visitor.NumUsers(map));
  EXPECT_EQ(1, visitor.NumOperands(map));

  // TODO(dehnert):  Add walking and counters for the wrapped computation.
}
364 
// A reduce applying an x+y computation is the only real operation in the
// entry computation:
//
//   param0[100x10] ---> (reduce x+y)
TEST_F(HloInstructionTest, TrivialReduce) {
  Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  Shape f32v100 = ShapeUtil::MakeShape(F32, {100});
  Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10});

  // Embedded scalar computation x + y that the reduce applies.
  HloComputation::Builder embedded_builder("f32+f32");
  HloInstruction* paramx = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32, "x"));
  HloInstruction* paramy = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32, "y"));
  embedded_builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy));
  auto module = CreateNewVerifiedModule();
  HloComputation* add_f32 =
      module->AddEmbeddedComputation(embedded_builder.Build());

  // Entry computation: an array parameter and an initial value feed the
  // reduce.
  HloComputation::Builder builder(TestName());
  HloInstruction* param0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32a100x10, "p"));
  HloInstruction* const0 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // A second constant that is not an operand of the reduce.
  builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* reduce = builder.AddInstruction(
      HloInstruction::CreateReduce(f32v100, param0, const0,
                                   /*dimensions_to_reduce=*/{1}, add_f32));
  module->AddEntryComputation(builder.Build());

  OpAndUserCollectingVisitor visitor;
  ASSERT_IS_OK(reduce->Accept(&visitor));

  // Only the entry computation is inspected; the reducer computation is not
  // walked.
  EXPECT_EQ(1, visitor.NumUsers(param0));
  EXPECT_EQ(1, visitor.NumUsers(const0));
  EXPECT_EQ(0, visitor.NumUsers(reduce));
  EXPECT_EQ(2, visitor.NumOperands(reduce));
}
407 
// Builds a few binary ops over two parameters, then replaces one parameter
// with the other in exactly one of the instructions via ReplaceUseWith.
TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) {
  HloComputation::Builder builder(TestName());
  auto foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  auto bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  auto add_foobar = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
  auto add_foofoo = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
  builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
                                                      add_foobar, add_foofoo));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(2, foo->user_count());
  EXPECT_EQ(1, bar->user_count());

  // Replace the use of foo in add_foofoo with bar.
  ASSERT_IS_OK(foo->ReplaceUseWith(add_foofoo, bar));

  EXPECT_EQ(1, foo->user_count());
  EXPECT_EQ(2, bar->user_count());

  // foo is now used only by add_foobar, whose operands are untouched.
  EXPECT_THAT(foo->users(), UnorderedElementsAre(add_foobar));
  EXPECT_THAT(add_foobar->operands(), ElementsAre(foo, bar));

  // bar gained add_foofoo as a user; both operand slots of add_foofoo were
  // rewritten.
  EXPECT_THAT(bar->users(), UnorderedElementsAre(add_foobar, add_foofoo));
  EXPECT_THAT(add_foofoo->operands(), ElementsAre(bar, bar));
}
442 
// Builds a tuple over several parameters (one of them repeated) and replaces
// that parameter with another one inside the tuple only.
TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) {
  HloComputation::Builder builder(TestName());
  HloInstruction* foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* baz =
      builder.AddInstruction(HloInstruction::CreateParameter(2, r0f32_, "baz"));

  // foo appears twice in the tuple and once more in a separate add.
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo}));
  HloInstruction* add_foobar = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(2, foo->user_count());
  EXPECT_THAT(foo->users(), UnorderedElementsAre(tuple, add_foobar));

  // Swap foo for bar, but only where the tuple uses it.
  ASSERT_IS_OK(foo->ReplaceUseWith(tuple, bar));

  EXPECT_THAT(foo->users(), UnorderedElementsAre(add_foobar));

  // Every tuple slot that held foo must now hold bar.
  EXPECT_THAT(tuple->operands(), ElementsAre(bar, bar, baz, bar));
}
472 
// Builds two unary instructions that use the same parameter, then replaces
// the parameter with another one in only one of them.
TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) {
  HloComputation::Builder builder(TestName());
  auto foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  auto bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));

  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(2, foo->user_count());
  EXPECT_THAT(foo->users(), UnorderedElementsAre(exp, log));
  EXPECT_EQ(0, bar->user_count());

  // Replace the use of foo in exp with bar.
  ASSERT_IS_OK(foo->ReplaceUseWith(exp, bar));

  // The use of foo in log should not have been affected.
  EXPECT_EQ(1, foo->user_count());
  EXPECT_THAT(foo->users(), UnorderedElementsAre(log));
  EXPECT_THAT(log->operands(), ElementsAre(foo));

  // Bar should now be used in exp. Matchers check size and contents together,
  // so nothing is dereferenced if a container is unexpectedly empty.
  EXPECT_EQ(1, bar->user_count());
  EXPECT_THAT(bar->users(), ElementsAre(exp));
  EXPECT_THAT(exp->operands(), ElementsAre(bar));
}
507 
// Builds a small graph of binary ops over two parameters and then redirects
// every use of the first parameter to the second one.
TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) {
  HloComputation::Builder builder(TestName());
  HloInstruction* foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* add_foobar = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
  HloInstruction* add_foofoo = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
  // Root of the graph: the sum of the two adds above.
  builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
                                                      add_foobar, add_foofoo));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(2, foo->user_count());
  EXPECT_EQ(1, bar->user_count());

  // Redirect every use of foo to bar.
  ASSERT_IS_OK(foo->ReplaceAllUsesWith(bar));

  EXPECT_EQ(0, foo->user_count());
  EXPECT_EQ(2, bar->user_count());

  EXPECT_THAT(bar->users(), UnorderedElementsAre(add_foobar, add_foofoo));
}
537 
// Builds a unary, a binary and a variadic op that all use the first of two
// parameters, then redirects every use of that parameter to the second one.
TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) {
  HloComputation::Builder builder(TestName());
  HloInstruction* foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));

  HloInstruction* add_foobar = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({foo, bar}));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(3, foo->user_count());
  EXPECT_EQ(2, bar->user_count());

  // Redirect every use of foo to bar.
  ASSERT_IS_OK(foo->ReplaceAllUsesWith(bar));

  EXPECT_EQ(0, foo->user_count());
  EXPECT_EQ(3, bar->user_count());

  EXPECT_THAT(bar->users(), UnorderedElementsAre(add_foobar, exp, tuple));
}
567 
568 // Simple visitor that collects and post-processes each node in the graph.
569 class NodeCollectorAndPostProcessor : public DfsHloVisitorWithDefault {
570  public:
NodeCollectorAndPostProcessor()571   NodeCollectorAndPostProcessor() {}
572 
Postprocess(HloInstruction * hlo)573   Status Postprocess(HloInstruction* hlo) override {
574     post_processed_nodes_.push_back(hlo);
575     return OkStatus();
576   }
577 
DefaultAction(HloInstruction * hlo_instruction)578   Status DefaultAction(HloInstruction* hlo_instruction) override {
579     visited_nodes_.push_back(hlo_instruction);
580     return OkStatus();
581   }
582 
visited_nodes()583   const std::vector<const HloInstruction*>& visited_nodes() {
584     return visited_nodes_;
585   }
586 
post_processed_nodes()587   const std::vector<const HloInstruction*>& post_processed_nodes() {
588     return post_processed_nodes_;
589   }
590 
591  private:
592   std::vector<const HloInstruction*> visited_nodes_;
593   std::vector<const HloInstruction*> post_processed_nodes_;
594 };
595 
596 // Returns true if "vec" contains distinct nodes.
Distinct(const std::vector<const HloInstruction * > & vec)597 bool Distinct(const std::vector<const HloInstruction*>& vec) {
598   std::set<const HloInstruction*> distinct_nodes(vec.begin(), vec.end());
599   return distinct_nodes.size() == vec.size();
600 }
601 
// Checks that nodes are visited and post-processed in the same order, and
// that no node is visited more than once, on a diamond-shaped graph:
//
//   exp = exp(foo)
//   log = log(foo)
//   add = add(exp, log)
TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) {
  HloComputation::Builder builder(TestName());
  HloInstruction* foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
  HloInstruction* log = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  NodeCollectorAndPostProcessor visitor;
  ASSERT_IS_OK(add->Accept(&visitor));
  // The visit order and the post-process order must be identical.
  EXPECT_EQ(visitor.visited_nodes(), visitor.post_processed_nodes());
  // No node may show up twice in the visit order.
  EXPECT_TRUE(Distinct(visitor.visited_nodes()));
}
628 
// Fuses a single unary op and checks that the fusion instruction takes over
// the operand/user edge to the constant.
TEST_F(HloInstructionTest, SingletonFusionOp) {
  HloComputation::Builder builder(TestName());
  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  // The fusion contains only the exp.
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {exp}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fusion->operands(), ElementsAre(constant));
  EXPECT_THAT(constant->users(), ElementsAre(fusion));
}
644 
// Fuses a single binary op and checks that the fusion instruction takes over
// the operand/user edges to both constants.
TEST_F(HloInstructionTest, BinaryFusionOp) {
  HloComputation::Builder builder(TestName());
  HloInstruction* constant1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* constant2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, constant1, constant2));
  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  // The fusion contains only the add.
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {add}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fusion->operands(), ElementsAre(constant1, constant2));
  EXPECT_THAT(constant1->users(), ElementsAre(fusion));
  EXPECT_THAT(constant2->users(), ElementsAre(fusion));
}
663 
// Fuses a chain of three unary ops and checks that only the constant feeding
// the chain remains as an operand of the fusion instruction.
TEST_F(HloInstructionTest, ChainFusionOp) {
  HloComputation::Builder builder(TestName());
  // Appends exp(input) to the computation being built.
  auto append_exp = [&](HloInstruction* input) {
    return builder.AddInstruction(
        HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, input));
  };

  HloInstruction* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* exp1 = append_exp(constant);
  HloInstruction* exp2 = append_exp(exp1);
  HloInstruction* exp3 = append_exp(exp2);

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {exp3, exp2, exp1}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fusion->operands(), ElementsAre(constant));
  EXPECT_THAT(constant->users(), ElementsAre(fusion));
}
684 
// Checks that op metadata set on fused instructions is carried over to the
// fusion instruction, to the instructions inside the fused computation, and
// to a clone of the fusion instruction.
TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
  HloComputation::Builder builder(TestName());
  // Create a chain of fused unary ops.
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1));
  OpMetadata metadata;
  metadata.set_op_name("tf_op");
  exp1->set_metadata(metadata);
  exp2->set_metadata(metadata);

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {exp2, exp1}, HloInstruction::FusionKind::kLoop);

  // Metadata on the fusion itself and on the two fused instructions.
  EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      metadata, fusion->fused_expression_root()->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      metadata, fusion->fused_expression_root()->operand(0)->metadata()));

  // The clone, not the original fusion, is what must be inspected here;
  // otherwise the "AndClone" half of this test checks nothing new.
  auto cloned = fusion->CloneWithNewOperands(fusion->shape(), {});
  EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, cloned->metadata()));
}
713 
// Outlines a single binary add into a call instruction and checks that the
// call takes over the add's operands, with each constant reporting the call as
// its only user.
TEST_F(HloInstructionTest, BinaryCallOp) {
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* call = entry->CreateCallInstruction({sum});

  EXPECT_THAT(call->operands(), ElementsAre(lhs, rhs));
  EXPECT_THAT(lhs->users(), ElementsAre(call));
  EXPECT_THAT(rhs->users(), ElementsAre(call));
}
731 
// Outlines a three-deep chain of exp ops into one call instruction and checks
// that the only remaining operand of the call is the constant feeding the
// chain.
TEST_F(HloInstructionTest, ChainCallOp) {
  HloComputation::Builder builder(TestName());
  HloInstruction* seed = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* first = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, seed));
  HloInstruction* second = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, first));
  HloInstruction* third = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, second));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  // The chain is handed over starting from its last instruction.
  HloInstruction* call = entry->CreateCallInstruction({third, second, first});

  EXPECT_THAT(call->operands(), ElementsAre(seed));
  EXPECT_THAT(seed->users(), ElementsAre(call));
}
751 
// Appends an extra instruction to an existing call with add_output=true and
// checks that both operands of the add left outside the call are then
// get-tuple-element instructions reading from the call.
TEST_F(HloInstructionTest, MultiOutputCallOp) {
  HloComputation::Builder builder(TestName());
  // Two branches hang off one constant: a three-deep exp chain and a single
  // exp. The add consumes the end of each branch.
  HloInstruction* seed = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* chain1 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, seed));
  HloInstruction* chain2 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, chain1));
  HloInstruction* chain3 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, chain2));
  HloInstruction* side_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, seed));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, chain3, side_exp));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  // Outline the chain first, then add the side branch as a further output.
  HloInstruction* call =
      entry->CreateCallInstruction({chain3, chain2, chain1});
  call->AppendInstructionIntoCalledComputation(side_exp, /*add_output=*/true);

  EXPECT_THAT(call->operands(), ElementsAre(seed));

  const HloInstruction* sum_lhs = sum->operand(0);
  EXPECT_EQ(sum_lhs->opcode(), HloOpcode::kGetTupleElement);
  EXPECT_THAT(sum_lhs->operands(), ElementsAre(call));

  const HloInstruction* sum_rhs = sum->operand(1);
  EXPECT_EQ(sum_rhs->opcode(), HloOpcode::kGetTupleElement);
  EXPECT_THAT(sum_rhs->operands(), ElementsAre(call));
}
779 
// Wraps a binary add in async instructions on the "parallel_thread" execution
// thread with one extra U32 scalar context shape, then checks the resulting
// start/done pair: shape, execution thread, operands, users and root.
TEST_F(HloInstructionTest, AsyncOp) {
  HloComputation::Builder builder(TestName());
  // The binary operation that gets wrapped asynchronously.
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction * async_done,
      entry->CreateAsyncInstructions(sum, {ShapeUtil::MakeScalarShape(U32)},
                                     "parallel_thread"));
  // The done instruction's only operand is the matching start instruction.
  const HloInstruction* async_start = async_done->operand(0);

  // Shape of the start instruction: a 3-tuple whose last element is the
  // requested context shape.
  EXPECT_EQ(async_start->shape().tuple_shapes_size(), 3);
  EXPECT_EQ(async_start->async_execution_thread(), "parallel_thread");
  EXPECT_EQ(async_done->async_execution_thread(), "parallel_thread");
  EXPECT_TRUE(ShapeUtil::Equal(async_start->shape().tuple_shapes(2),
                               ShapeUtil::MakeScalarShape(U32)));

  // The wrapped computation is placed on the same execution thread.
  EXPECT_EQ(async_start->async_wrapped_computation()->execution_thread(),
            "parallel_thread");
  EXPECT_EQ(async_done->async_wrapped_computation()->execution_thread(),
            "parallel_thread");

  // The start instruction takes over the operands of the wrapped add.
  EXPECT_THAT(async_start->operands(), ElementsAre(lhs, rhs));
  EXPECT_THAT(lhs->users(), ElementsAre(async_start));
  EXPECT_THAT(rhs->users(), ElementsAre(async_start));
  EXPECT_EQ(entry->root_instruction(), async_done);
}
811 
// Two outfeeds of the same constant differ only in the layout of their
// outfeed shape; each clone is expected to report the outfeed shape of the
// instruction it was cloned from.
TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  // Identical dimensions, layouts {1, 0} versus {0, 1}.
  const Shape layout10_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  const Shape layout01_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* outfeed10 = builder.AddInstruction(
      HloInstruction::CreateOutfeed(layout10_shape, matrix, token, ""));
  HloInstruction* outfeed01 = builder.AddInstruction(
      HloInstruction::CreateOutfeed(layout01_shape, matrix, token, ""));

  HloInstruction* clone01 = builder.AddInstruction(outfeed01->Clone());
  HloInstruction* clone10 = builder.AddInstruction(outfeed10->Clone());

  EXPECT_TRUE(ShapeUtil::Equal(clone01->outfeed_shape(), layout01_shape));
  EXPECT_TRUE(ShapeUtil::Equal(clone10->outfeed_shape(), layout10_shape));
}
833 
// Gives the two elements of a tuple shape different layouts and checks that a
// clone of the tuple instruction has a shape equal to the original's.
TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  HloInstruction* pair =
      builder.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));

  // Overwrite the layout of each tuple element in place.
  Shape* first_element =
      ShapeUtil::GetMutableSubshape(pair->mutable_shape(), {0});
  *first_element->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  Shape* second_element =
      ShapeUtil::GetMutableSubshape(pair->mutable_shape(), {1});
  *second_element->mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  auto pair_clone = pair->Clone();
  EXPECT_TRUE(ShapeUtil::Equal(pair_clone->shape(), pair->shape()));
}
850 
// Clones a sharded tuple into a shape with the same tuple structure and the
// same leaf ranks; the clone is expected to keep the sharding.
TEST_F(HloInstructionTest, PreserveShardingThroughCompatibleClone) {
  const HloSharding device_sharding = HloSharding::AssignDevice(5);
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  HloInstruction* pair =
      builder.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));
  pair->set_sharding(device_sharding);

  // Two rank-2 leaves, as in the original tuple; only the dimension sizes
  // differ.
  const Shape leaf_shape = ShapeUtil::MakeShape(F32, {3, 3});
  const Shape target_shape = ShapeUtil::MakeTupleShape({leaf_shape, leaf_shape});
  auto pair_clone = pair->CloneWithNewOperands(target_shape, {});
  EXPECT_EQ(pair_clone->sharding(), device_sharding);
}
869 
// Clones a sharded two-element tuple into a three-element tuple shape; since
// the tuple structure no longer matches, the clone is expected to have no
// sharding.
TEST_F(HloInstructionTest,
       DoNotPreserveShardingThroughTupleTreeIncompatibleClone) {
  const HloSharding device_sharding = HloSharding::AssignDevice(5);
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  HloInstruction* pair =
      builder.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));
  pair->set_sharding(device_sharding);

  // Same leaf shape as the original, but three leaves instead of two.
  const Shape leaf_shape = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape target_shape =
      ShapeUtil::MakeTupleShape({leaf_shape, leaf_shape, leaf_shape});
  auto pair_clone = pair->CloneWithNewOperands(target_shape, {});
  EXPECT_FALSE(pair_clone->has_sharding());
}
889 
// Clones a sharded tuple of rank-2 leaves into a tuple of rank-3 leaves; since
// the leaf ranks no longer match, the clone is expected to have no sharding.
TEST_F(HloInstructionTest,
       DoNotPreserveShardingThroughLeafRankIncompatibleClone) {
  const HloSharding device_sharding = HloSharding::AssignDevice(5);
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  HloInstruction* pair =
      builder.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));
  pair->set_sharding(device_sharding);

  // Same number of tuple elements as the original, but each leaf has rank 3
  // instead of rank 2.
  const Shape leaf_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
  const Shape target_shape = ShapeUtil::MakeTupleShape({leaf_shape, leaf_shape});
  auto pair_clone = pair->CloneWithNewOperands(target_shape, {});
  EXPECT_FALSE(pair_clone->has_sharding());
}
908 
// Fuses a chain of map instructions one at a time and checks after every step
// that the fusion instruction lists exactly one called computation: its own
// fused-instructions computation.
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  auto module = CreateNewVerifiedModule();

  // Adds a fresh embedded computation consisting of one scalar parameter.
  auto add_map_computation = [&]() {
    auto map_builder = HloComputation::Builder("FusionMap");
    map_builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "param"));
    return module->AddEmbeddedComputation(map_builder.Build());
  };

  HloComputation* computation_x = add_map_computation();
  HloComputation* computation_y = add_map_computation();

  // constant -> map(x) -> map(x) -> map(y)
  HloComputation::Builder builder(TestName());
  HloInstruction* seed = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* map_1_x = builder.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {seed}, computation_x));
  HloInstruction* map_2_x = builder.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {map_1_x}, computation_x));
  HloInstruction* map_3_y = builder.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {map_2_x}, computation_y));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  // Start the fusion from the last map only.
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {map_3_y}, HloInstruction::FusionKind::kLoop);
  HloComputation* fused_computation = fusion->fused_instructions_computation();
  EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));

  // Pull in the remaining maps one by one; the list must not change.
  fusion->FuseInstruction(map_2_x);
  EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));

  fusion->FuseInstruction(map_1_x);
  EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
}
946 
// Fuses every non-constant instruction of the expression below and checks the
// order of the fusion instruction's operands:
//
//   add = Add(C1, C2)
//   clamp = Clamp(C2, add, add)
//   exp = Exp(add)
//   mul = Mul(exp, C3)
//   sub = Sub(mul, clamp)
//   tuple = Tuple({sub, sub, mul, C1})
//
// The expression repeats operands within one instruction, mixes shapes, and
// uses the same value in several places.
TEST_F(HloInstructionTest, ComplexFusionOp) {
  HloComputation::Builder builder(TestName());
  HloInstruction* c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.1f)));
  HloInstruction* c3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(9.0f)));

  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
  HloInstruction* clamp = builder.AddInstruction(
      HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, c2, add, add));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add));
  HloInstruction* mul = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, exp, c3));
  HloInstruction* sub = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, mul, clamp));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1}));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);

  // The fusion's operands() are expected in the order in which their users
  // were fused.
  EXPECT_THAT(fusion->operands(), ElementsAre(c1, c3, c2));
  EXPECT_THAT(c1->users(), ElementsAre(fusion));
}
990 
991 // Convenience function for comparing two HloInstructions.
Identical(const HloInstruction & instruction1,const HloInstruction & instruction2)992 static bool Identical(const HloInstruction& instruction1,
993                       const HloInstruction& instruction2) {
994   // Verify Identical is reflexive for both instructions.
995   EXPECT_TRUE(instruction1.Identical(instruction1));
996   EXPECT_TRUE(instruction2.Identical(instruction2));
997 
998   bool is_equal = instruction1.Identical(instruction2);
999   // Verify Identical is symmetric.
1000   EXPECT_EQ(is_equal, instruction2.Identical(instruction1));
1001   return is_equal;
1002 }
1003 
1004 // Convenience function for comparing two HloInstructions for structural
1005 // equality.
StructuralEqual(const HloInstruction & instruction1,const HloInstruction & instruction2)1006 static bool StructuralEqual(const HloInstruction& instruction1,
1007                             const HloInstruction& instruction2) {
1008   auto eq_operand_shapes = [](const HloInstruction* a,
1009                               const HloInstruction* b) {
1010     return ShapeUtil::Equal(a->shape(), b->shape());
1011   };
1012   auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
1013     return *a == *b;
1014   };
1015 
1016   // Verify Identical is reflexive for both instructions.
1017   EXPECT_TRUE(
1018       instruction1.Identical(instruction1, eq_operand_shapes, eq_computations));
1019   EXPECT_TRUE(
1020       instruction2.Identical(instruction2, eq_operand_shapes, eq_computations));
1021 
1022   bool is_equal =
1023       instruction1.Identical(instruction2, eq_operand_shapes, eq_computations);
1024   // Verify Identical is symmetric.
1025   EXPECT_EQ(is_equal, instruction2.Identical(instruction1, eq_operand_shapes,
1026                                              eq_computations));
1027   return is_equal;
1028 }
1029 
TEST_F(HloInstructionTest,IdenticalInstructions)1030 TEST_F(HloInstructionTest, IdenticalInstructions) {
1031   // Test HloInstruction::Identical with some subset of instructions types.
1032 
1033   // Create a set of random constant operands to use below. Make them matrices
1034   // so dimensions are interesting.
1035   auto operand1 = HloInstruction::CreateConstant(
1036       LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
1037   auto operand2 = HloInstruction::CreateConstant(
1038       LiteralUtil::CreateR2<float>({{10.0, 20.0}, {30.0, 40.0}}));
1039   auto vector_operand = HloInstruction::CreateConstant(
1040       LiteralUtil::CreateR1<float>({42.0, 123.0}));
1041   Shape shape = operand1->shape();
1042 
1043   // Convenient short names for the operands.
1044   HloInstruction* op1 = operand1.get();
1045   HloInstruction* op2 = operand2.get();
1046 
1047   // Operations which only depend on their operands and opcode.
1048   EXPECT_TRUE(
1049       Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
1050                 *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1)));
1051   EXPECT_FALSE(
1052       Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
1053                 *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op2)));
1054   EXPECT_FALSE(
1055       Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
1056                 *HloInstruction::CreateUnary(shape, HloOpcode::kNegate, op1)));
1057 
1058   // Tuples.
1059   EXPECT_TRUE(Identical(*HloInstruction::CreateTuple({op1, op2}),
1060                         *HloInstruction::CreateTuple({op1, op2})));
1061   EXPECT_FALSE(Identical(*HloInstruction::CreateTuple({op1, op2}),
1062                          *HloInstruction::CreateTuple({op2, op1})));
1063 
1064   // Broadcasts.
1065   EXPECT_TRUE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}),
1066                         *HloInstruction::CreateBroadcast(shape, op1, {0, 1})));
1067   EXPECT_FALSE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}),
1068                          *HloInstruction::CreateBroadcast(shape, op1, {1, 0})));
1069   Shape bcast_shape1 = ShapeUtil::MakeShape(F32, {2, 2, 42});
1070   Shape bcast_shape2 = ShapeUtil::MakeShape(F32, {2, 2, 123});
1071   EXPECT_FALSE(
1072       Identical(*HloInstruction::CreateBroadcast(bcast_shape1, op1, {0, 1}),
1073                 *HloInstruction::CreateBroadcast(bcast_shape2, op1, {0, 1})));
1074 
1075   // Binary operands.
1076   EXPECT_TRUE(Identical(
1077       *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
1078       *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2)));
1079   EXPECT_FALSE(Identical(
1080       *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
1081       *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op2, op1)));
1082   EXPECT_FALSE(Identical(
1083       *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
1084       *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2)));
1085 }
1086 
// Parses a module with three calls on the same parameter: two target subcomp1
// (sine) and one targets subcomp2 (cosine). Calls to the same computation are
// expected to be structurally equal; calls to different computations are not.
TEST_F(HloInstructionTest, IdenticalCallInstructions) {
  const char* const kModuleText = R"(
HloModule Module

subcomp1 (x: f32[]) -> f32[] {
  x = f32[] parameter(0)
  ROOT n = f32[] sine(x)
}

subcomp2 (x: f32[]) -> f32[] {
  x = f32[] parameter(0)
  ROOT n = f32[] cosine(x)
}

ENTRY entry (param: f32[]) -> (f32[], f32[], f32[]) {
  p = f32[] parameter(0)
  t1 = f32[] call(p), to_apply=subcomp1
  t2 = f32[] call(p), to_apply=subcomp1
  t3 = f32[] call(p), to_apply=subcomp2
  ROOT t = (f32[], f32[], f32[]) tuple(t1, t2, t3)
 }
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));

  // The root tuple's operands are the three calls, in declaration order.
  const HloInstruction* root_tuple =
      module->entry_computation()->root_instruction();
  const HloInstruction* sine_call_a = root_tuple->operand(0);
  const HloInstruction* sine_call_b = root_tuple->operand(1);
  const HloInstruction* cosine_call = root_tuple->operand(2);

  EXPECT_TRUE(StructuralEqual(*sine_call_a, *sine_call_b));
  EXPECT_FALSE(StructuralEqual(*sine_call_a, *cosine_call));
}
1120 
// Runs a FunctionVisitor through HloInstruction::Accept starting at the root
// of a diamond-shaped graph and checks the visit order: `param` feeds both
// `negate` and `exp`, and `add` consumes the two. Every instruction must be
// visited exactly once, with param first and add last.
TEST_F(HloInstructionTest, FunctionVisitor) {
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "0"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_f32, HloOpcode::kExp, param));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, negate, exp));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Record the position at which each instruction is visited; seeing an
  // instruction a second time is a test failure.
  int next_position = 0;
  absl::flat_hash_map<HloInstruction*, int> position_of;
  FunctionVisitor visitor([&](HloInstruction* visited) {
    EXPECT_FALSE(position_of.contains(visited));
    position_of[visited] = next_position;
    ++next_position;
    return OkStatus();
  });
  EXPECT_IS_OK(add->Accept(&visitor));

  EXPECT_EQ(0, position_of.at(param));
  // negate and exp may come in either order, but must occupy positions 1 and
  // 2 between them.
  const int exp_position = position_of.at(exp);
  const int negate_position = position_of.at(negate);
  EXPECT_TRUE(exp_position == 1 || exp_position == 2);
  EXPECT_TRUE(negate_position == 1 || negate_position == 2);
  EXPECT_NE(exp_position, negate_position);
  EXPECT_EQ(3, position_of.at(add));
}
1160 
// An add of two same-shaped vectors is expected to be elementwise as a whole
// and elementwise on each of its operands.
TEST_F(HloInstructionTest, FullyElementwise) {
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {5});
  HloComputation::Builder builder(TestName());
  HloInstruction* x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape, "x"));
  HloInstruction* y = builder.AddInstruction(
      HloInstruction::CreateParameter(1, vector_shape, "y"));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(vector_shape, HloOpcode::kAdd, x, y));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(sum->IsElementwise());
  for (int64_t operand_idx = 0; operand_idx < sum->operand_count();
       ++operand_idx) {
    EXPECT_TRUE(sum->IsElementwiseOnOperand(operand_idx));
  }
}
1178 
// A map instruction applying a one-parameter scalar computation to a matrix is
// expected to report itself as elementwise.
TEST_F(HloInstructionTest, MapIsElementwise) {
  auto module = CreateNewVerifiedModule();
  const Shape matrix_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0});
  HloComputation::Builder builder(TestName());

  // The mapped computation "id" consists of a single scalar parameter.
  HloComputation::Builder id_builder("id");
  id_builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
  HloComputation* id_computation =
      module->AddEmbeddedComputation(id_builder.Build());

  HloInstruction* x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "x"));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(matrix_shape, {x}, id_computation));
  module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(map->IsElementwise());
}
1195 
// Builds the fused expression below and checks per-operand elementwise-ness of
// the resulting fusion instruction:
//
// p0     p1   p2   p3
//   \   /    /     |
//    mul    /      |
//      \   /       |
//       div     broadcast
//          \    /
//           max
//
// p3 only reaches the root through the broadcast, which is not elementwise, so
// the fusion is expected to be elementwise on every operand except p3 and not
// elementwise as a whole.
TEST_F(HloInstructionTest, PartiallyElementwise) {
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {5});
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {3, 5});

  HloComputation::Builder builder("PartiallyElementwise");
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "p0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, matrix_shape, "p1"));
  HloInstruction* p2 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, matrix_shape, "p2"));
  HloInstruction* p3 = builder.AddInstruction(
      HloInstruction::CreateParameter(3, vector_shape, "p3"));
  HloInstruction* mul = builder.AddInstruction(
      HloInstruction::CreateBinary(matrix_shape, HloOpcode::kMultiply, p0, p1));
  HloInstruction* div = builder.AddInstruction(
      HloInstruction::CreateBinary(matrix_shape, HloOpcode::kDivide, mul, p2));
  // Dimension 0 of shape [5] is mapped to dimension 1 of shape [3x5].
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(matrix_shape, p3, {1}));
  HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
      matrix_shape, HloOpcode::kMaximum, div, broadcast));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop);

  EXPECT_FALSE(fusion->IsElementwise());
  const int64_t operand_total = fusion->operand_count();
  for (int64_t idx = 0; idx < operand_total; ++idx) {
    const bool is_broadcast_input = fusion->operand(idx) == p3;
    if (is_broadcast_input) {
      EXPECT_FALSE(fusion->IsElementwiseOnOperand(idx));
      continue;
    }
    EXPECT_TRUE(fusion->IsElementwiseOnOperand(idx));
  }
}
1246 
// Builds the fused expression below, in which the broadcast of y is consumed
// twice, and checks per-operand elementwise-ness of the fusion instruction:
//
//         y
//        /
// x   broadcast
//  \   /  |
//   min   |
//     \   /
//      sub
//
// The fusion is expected to be elementwise on every operand except y and not
// elementwise as a whole.
TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {5});

  HloComputation::Builder builder("PartiallyElementwiseWithReuse");
  HloInstruction* x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape, "x"));
  HloInstruction* y = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "y"));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(vector_shape, y, {}));
  HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape, HloOpcode::kMinimum, x, broadcast));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      vector_shape, HloOpcode::kSubtract, min, broadcast));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {sub, broadcast, min}, HloInstruction::FusionKind::kLoop);

  EXPECT_FALSE(fusion->IsElementwise());
  const int64_t operand_total = fusion->operand_count();
  for (int64_t idx = 0; idx < operand_total; ++idx) {
    const bool is_broadcast_input = fusion->operand(idx) == y;
    if (is_broadcast_input) {
      EXPECT_FALSE(fusion->IsElementwiseOnOperand(idx));
      continue;
    }
    EXPECT_TRUE(fusion->IsElementwiseOnOperand(idx));
  }
}
1286 
// Fuses a transpose feeding a dot and checks that Clone() of the fusion does
// not mangle any shape inside the fused expression:
//
// x     y
// |     |
// |  transpose
//  \   /
//   dot
//
// Shapes are compared at the fused root, at both of its operands, and at the
// transpose's input; finally the two fusions must be structurally equal.
TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape rhs_transposed_shape = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x = builder.AddInstruction(
      HloInstruction::CreateParameter(0, lhs_shape, "x"));
  HloInstruction* y = builder.AddInstruction(
      HloInstruction::CreateParameter(1, rhs_shape, "y"));
  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(rhs_transposed_shape, y, {1, 0}));
  // Contract dimension 1 of the lhs with dimension 0 of the transposed rhs.
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {dot, transpose}, HloInstruction::FusionKind::kLoop);

  auto fusion_clone = fusion->Clone();
  const HloInstruction* original_root = fusion->fused_expression_root();
  const HloInstruction* cloned_root = fusion_clone->fused_expression_root();

  EXPECT_TRUE(ShapeUtil::Equal(original_root->shape(), cloned_root->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(original_root->operand(0)->shape(),
                               cloned_root->operand(0)->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(original_root->operand(1)->shape(),
                               cloned_root->operand(1)->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(original_root->operand(1)->operand(0)->shape(),
                               cloned_root->operand(1)->operand(0)->shape()));
  EXPECT_TRUE(StructuralEqual(*fusion, *fusion_clone));
}
1332 
// Fuses the multiply producer into an existing loop fusion and checks that the
// multiply loses its user but still has the same parent computation as the
// fusion, i.e. FuseInstruction() does not remove it.
TEST_F(HloInstructionTest, FuseInstructionKeepsInstruction) {
  constexpr char kModuleText[] = R"(
  HloModule test_module
  fused_add {
    p0 = f32[32,32]{1,0} parameter(0)
    p1 = f32[32,32]{1,0} parameter(1)
    ROOT add = f32[32,32]{1,0} add(p0, p1)
  }

  ENTRY reduce {
    p2 = f32[32,32]{1,0} parameter(0)
    p3 = f32[32,32]{1,0} parameter(1)
    c1 = f32[] constant(1)
    broadcast = f32[32,32]{1,0} broadcast(c1), dimensions={}
    mul = f32[32,32]{1,0} multiply(p2, p3)
    ROOT add = f32[32,32]{1,0} fusion(mul, broadcast), kind=kLoop, calls=fused_add
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleText));
  // The entry root is the fusion; its first operand is the multiply.
  HloInstruction* fusion = module->entry_computation()->root_instruction();
  HloInstruction* multiply = fusion->mutable_operand(0);

  EXPECT_EQ(1, multiply->user_count());
  fusion->FuseInstruction(multiply);
  EXPECT_EQ(0, multiply->user_count());
  // The fused instruction is still present in the computation.
  EXPECT_EQ(fusion->parent(), multiply->parent());
}
1360 
// Fuses a multiply that is used by both the fusion and the root tuple into the
// fusion as an additional output. Afterwards the multiply has no users left,
// but still has the same parent computation as the root, i.e.
// FuseInstructionIntoMultiOutput() does not remove it.
TEST_F(HloInstructionTest, FuseInstructionIntoMultiOutputKeepsInstruction) {
  constexpr char kModuleText[] = R"(
  HloModule test_module
  fused_add {
    p0 = f32[32,32]{1,0} parameter(0)
    p1 = f32[32,32]{1,0} parameter(1)
    ROOT add = f32[32,32]{1,0} add(p0, p1)
  }

  ENTRY reduce {
    p2 = f32[32,32]{1,0} parameter(0)
    p3 = f32[32,32]{1,0} parameter(1)
    c1 = f32[] constant(1)
    mul = f32[32,32]{1,0} multiply(p2, p3)
    broadcast = f32[32,32]{1,0} broadcast(c1), dimensions={}
    add = f32[32,32]{1,0} fusion(mul, broadcast), kind=kLoop, calls=fused_add
    ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(mul, add)
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleText));
  // The root tuple holds the multiply first and the fusion second.
  HloInstruction* root_tuple = module->entry_computation()->root_instruction();
  HloInstruction* multiply = root_tuple->mutable_operand(0);
  HloInstruction* fusion = root_tuple->mutable_operand(1);

  // Before fusing, the multiply is used by the fusion and by the root tuple.
  EXPECT_EQ(2, multiply->user_count());
  fusion->FuseInstructionIntoMultiOutput(multiply);
  EXPECT_EQ(0, multiply->user_count());
  // The fused instruction is still present in the computation.
  EXPECT_EQ(root_tuple->parent(), multiply->parent());
}
1390 
// After replacing every use of `x` with `y`, a fusion that originally took
// both as operands is expected to end up with `y` as its single operand and
// a single fused parameter, rather than two operands that are both `y`.
TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) {
  // Fused expression:
  //
  // x     y
  // |     |
  // |  transpose
  //  \   /
  //   dot
  const Shape s = ShapeUtil::MakeShape(F32, {10, 10});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y"));
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      s, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {dot, transpose}, HloInstruction::FusionKind::kLoop);

  // TF_EXPECT_OK reports the returned status on failure instead of a bare
  // "false".
  TF_EXPECT_OK(x->ReplaceAllUsesWith(y));

  EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y));
  EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1);
}
1424 
// Two fusions wrapping different unary ops (exp vs. negate) of the same
// parameter must not compare structurally equal, while a fusion and its
// clone must.
TEST_F(HloInstructionTest, FusionEquality) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Create two fusion instructions containing a single unary operation.
  auto parameter =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, parameter));
  auto neg = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, parameter));
  auto* computation = module->AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {exp}, HloInstruction::FusionKind::kLoop);
  auto* fusion2 = computation->CreateFusionInstruction(
      {neg}, HloInstruction::FusionKind::kLoop);
  EXPECT_FALSE(StructuralEqual(*fusion, *fusion2));

  auto clone = fusion->Clone();
  EXPECT_TRUE(StructuralEqual(*fusion, *clone));
}
1446 
// Structural equality for fusions that themselves contain a fusion: an
// output fusion built around `add` must equal its clone and must differ from
// the analogous fusion built around `sub`.
TEST_F(HloInstructionTest, NestedFusionEquality) {
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  // Build a nested fusion computation.
  Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
  auto a = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
  auto b = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
  auto b_t = builder.AddInstruction(
      HloInstruction::CreateTranspose(data_shape, b, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto dot = builder.AddInstruction(HloInstruction::CreateDot(
      data_shape, a, b_t, dot_dnums, DefaultPrecisionConfig(2)));
  auto one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  auto add_operand = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));
  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kAdd, dot, add_operand));
  auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kSubtract, dot, add_operand));
  builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kMultiply, add, sub));
  auto computation = module->AddEntryComputation(builder.Build());

  // Inner fusion: dot plus the transpose feeding it.
  auto nested_fusion = computation->CreateFusionInstruction(
      {dot, b_t}, HloInstruction::FusionKind::kLoop);

  // Two outer fusions, both requested with the inner fusion, differing only
  // in the binary op (add vs. subtract).
  auto fusion = computation->CreateFusionInstruction(
      {add, nested_fusion}, HloInstruction::FusionKind::kOutput);
  auto fusion2 = computation->CreateFusionInstruction(
      {sub, nested_fusion}, HloInstruction::FusionKind::kOutput);
  auto clone = fusion->Clone();
  EXPECT_TRUE(StructuralEqual(*fusion, *clone));
  EXPECT_FALSE(StructuralEqual(*fusion, *fusion2));
}
1487 
// Test that the suffix string added to cloned instructions is not
// duplicated. Rather a numeric incrementing value should be appended. That
// is, we want "foo.clone2", not "foo.clone.clone".
TEST_F(HloInstructionTest, CloneSuffixNames) {
  // Test cloning the same instruction multiple times.
  auto foo =
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "foo");
  EXPECT_EQ(foo->Clone()->name(), "foo.clone");
  EXPECT_EQ(foo->Clone()->Clone()->name(), "foo.clone2");
  EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "foo.clone3");

  // Test custom suffixes.
  EXPECT_EQ(foo->Clone("bar")->name(), "foo.bar");
  EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "foo.bar2");
  EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(), "foo.bar2.clone");

  // Test instruction name with a dot.
  auto foo_baz = HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {}), "foo.baz");
  EXPECT_EQ(foo_baz->Clone()->name(), "foo.baz.clone");

  // Test incrementing a large number after the suffix.
  auto foo_clone234 = HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {}), "foo.clone234");
  EXPECT_EQ(foo_clone234->Clone()->name(), "foo.clone235");

  // Test a non-numeric string after the cloning suffix.
  auto foo_clonexyz = HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {}), "foo.clonexyz");
  EXPECT_EQ(foo_clonexyz->Clone()->name(), "foo.clonexyz.clone");

  // Test a name with multiple appearances of the suffix.
  auto foo_clone_clone3 = HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {}), "foo.clone.clone3");
  EXPECT_EQ(foo_clone_clone3->Clone()->name(), "foo.clone.clone4");
}
1525 
TEST_F(HloInstructionTest,Stringification)1526 TEST_F(HloInstructionTest, Stringification) {
1527   // Tests stringification of a simple op, fusion, while, and conditional.
1528   const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
1529   const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
1530   const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
1531   const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
1532 
1533   HloComputation::Builder builder("TransposeDot");
1534   HloInstruction* x =
1535       builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
1536   HloInstruction* y =
1537       builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
1538   HloInstruction* reshape =
1539       builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
1540   DotDimensionNumbers dot_dnums;
1541   dot_dnums.add_lhs_contracting_dimensions(1);
1542   dot_dnums.add_rhs_contracting_dimensions(0);
1543   HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
1544       sout, x, reshape, dot_dnums, DefaultPrecisionConfig(2)));
1545 
1546   auto options = HloPrintOptions().set_print_metadata(false);
1547 
1548   EXPECT_EQ(dot->ToString(options),
1549             "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} "
1550             "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}");
1551 
1552   auto options2 = HloPrintOptions()
1553                       .set_print_metadata(false)
1554                       .set_print_operand_shape(false)
1555                       .set_print_percent(false)
1556                       .set_include_layout_in_shapes(false);
1557 
1558   EXPECT_EQ(dot->ToString(options2),
1559             "dot = f32[5,20] dot(x, transpose), "
1560             "lhs_contracting_dims={1}, rhs_contracting_dims={0}");
1561 
1562   auto module = CreateNewVerifiedModule();
1563   auto* computation = module->AddEntryComputation(builder.Build());
1564 
1565   HloInstruction* loop = builder.AddInstruction(
1566       HloInstruction::CreateWhile(sout, computation, computation, x));
1567   EXPECT_EQ(loop->ToString(options),
1568             "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), "
1569             "condition=%TransposeDot, body=%TransposeDot");
1570 
1571   auto pred = builder.AddInstruction(
1572       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
1573   HloInstruction* conditional =
1574       builder.AddInstruction(HloInstruction::CreateConditional(
1575           sout, pred, x, computation, x, computation));
1576   EXPECT_EQ(conditional->ToString(options),
1577             "%conditional = f32[5,20]{1,0} conditional(pred[] %constant, "
1578             "f32[5,10]{1,0} %x, f32[5,10]{1,0} %x), "
1579             "true_computation=%TransposeDot, false_computation=%TransposeDot");
1580 }
1581 
// Checks the default string form of a gather whose index_vector_dim (4) is
// the last dimension of the start-indices operand.
TEST_F(HloInstructionTest, StringifyGather_0) {
  Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  Shape start_indices_tensor_shape =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
  Shape gather_result_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Gather");
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
  HloInstruction* start_indices =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, start_indices_tensor_shape, "start_indices"));

  HloInstruction* gather_instruction = builder.AddInstruction(
      HloInstruction::CreateGather(gather_result_shape, input, start_indices,
                                   HloGatherInstruction::MakeGatherDimNumbers(
                                       /*offset_dims=*/{4, 5, 6, 7, 8},
                                       /*collapsed_slice_dims=*/{},
                                       /*start_index_map=*/{0, 1, 2, 3, 4},
                                       /*index_vector_dim=*/4),
                                   /*slice_sizes=*/{30, 29, 28, 27, 26},
                                   /*indices_are_sorted=*/false));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(gather_instruction->ToString(),
            "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
            "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
            "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), "
            "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
            "start_index_map={0,1,2,3,4}, "
            "index_vector_dim=4, slice_sizes={30,29,28,27,26}");
}
1617 
// Same as StringifyGather_0, but with index_vector_dim (2) in the middle of
// the start-indices operand's dimensions.
TEST_F(HloInstructionTest, StringifyGather_1) {
  Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  Shape start_indices_tensor_shape =
      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  Shape gather_result_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Gather");
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
  HloInstruction* start_indices =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, start_indices_tensor_shape, "start_indices"));

  HloInstruction* gather_instruction = builder.AddInstruction(
      HloInstruction::CreateGather(gather_result_shape, input, start_indices,
                                   HloGatherInstruction::MakeGatherDimNumbers(
                                       /*offset_dims=*/{4, 5, 6, 7, 8},
                                       /*collapsed_slice_dims=*/{},
                                       /*start_index_map=*/{0, 1, 2, 3, 4},
                                       /*index_vector_dim=*/2),
                                   /*slice_sizes=*/{30, 29, 28, 27, 26},
                                   /*indices_are_sorted=*/false));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(gather_instruction->ToString(),
            "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
            "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
            "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), "
            "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
            "start_index_map={0,1,2,3,4}, "
            "index_vector_dim=2, slice_sizes={30,29,28,27,26}");
}
1653 
// Checks the default string form of a scatter instruction, including its
// dimension numbers and the name of the update computation it applies.
TEST_F(HloInstructionTest, StringifyScatter) {
  Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  Shape scatter_indices_tensor_shape =
      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  Shape scatter_updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Scatter");
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
  HloInstruction* scatter_indices =
      builder.AddInstruction(HloInstruction::CreateParameter(
          1, scatter_indices_tensor_shape, "scatter_indices"));
  HloInstruction* scatter_updates =
      builder.AddInstruction(HloInstruction::CreateParameter(
          2, scatter_updates_shape, "scatter_updates"));

  // The update computation consists of just its two scalar parameters.
  HloComputation::Builder update_builder("Scatter.update");
  update_builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p1"));
  update_builder.AddInstruction(
      HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2"));

  auto module = CreateNewVerifiedModule();
  auto* update_computation =
      module->AddEmbeddedComputation(update_builder.Build());

  HloInstruction* scatter_instruction =
      builder.AddInstruction(HloInstruction::CreateScatter(
          input_tensor_shape, input, scatter_indices, scatter_updates,
          update_computation,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/2),
          /*indices_are_sorted=*/false,
          /*unique_indices=*/false));
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(
      scatter_instruction->ToString(),
      "%scatter = f32[50,49,48,47,46]{4,3,2,1,0} "
      "scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
      "s64[10,9,5,7,6]{4,3,2,1,0} %scatter_indices, "
      "f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %scatter_updates), "
      "update_window_dims={4,5,6,7,8}, inserted_window_dims={}, "
      "scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, "
      "to_apply=%Scatter.update");
}
1704 
// Checks module stringification of an async-start/update/done chain that
// wraps a custom call, once with default print options and once with
// set_syntax_sugar_async_ops(false). The expected strings are compared
// byte-for-byte, including blank lines.
TEST_F(HloInstructionTest, StringifyAsyncOps) {
  const Shape s1 = ShapeUtil::MakeShape(F32, {10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20});
  // Shape used for the async-start and async-update instructions.
  const Shape s_tuple = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeTupleShape({s1}), s2, ShapeUtil::MakeShape(S32, {})});

  // Computation wrapped by the async ops: a single custom call on p0.
  HloComputation::Builder async_builder("AsyncOp");
  HloInstruction* param = async_builder.AddInstruction(
      HloInstruction::CreateParameter(0, s1, "p0"));
  async_builder.AddInstruction(
      HloInstruction::CreateCustomCall(s2, {param},
                                       /*custom_call_target=*/"foo"));
  std::unique_ptr<HloComputation> async_computation = async_builder.Build();

  HloComputation::Builder entry_builder("Entry");
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, s1, "p0"));
  HloInstruction* async_start =
      entry_builder.AddInstruction(HloInstruction::CreateAsyncStart(
          s_tuple, {entry_param}, async_computation.get(),
          /*async_group_id=*/std::nullopt,
          /*async_execution_thread=*/"parallel_thread"));
  HloInstruction* async_update =
      entry_builder.AddInstruction(HloInstruction::CreateAsyncUpdate(
          s_tuple, async_start, async_computation.get(),
          /*async_group_id=*/std::nullopt,
          /*async_execution_thread=*/"parallel_thread"));
  entry_builder.AddInstruction(HloInstruction::CreateAsyncDone(
      s2, async_update, async_computation.get(),
      /*async_group_id=*/std::nullopt,
      /*async_execution_thread=*/"parallel_thread"));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(entry_builder.Build());
  module->AddEmbeddedComputation(std::move(async_computation));

  const std::string expected_with_syntax_sugar =
      R"(HloModule StringifyAsyncOps, entry_computation_layout={(f32[10]{0})->f32[20]{0}}

ENTRY %Entry (p0: f32[10]) -> f32[20] {
  %p0 = f32[10]{0} parameter(0)
  %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), async_execution_thread="parallel_thread", custom_call_target="foo"
  %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="foo"
  ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_execution_thread="parallel_thread", custom_call_target="foo"
}

)";
  EXPECT_EQ(module->ToString(), expected_with_syntax_sugar);
  const std::string expected_without_syntax_sugar =
      R"(HloModule StringifyAsyncOps, entry_computation_layout={(f32[10]{0})->f32[20]{0}}

%AsyncOp (p0.1: f32[10]) -> f32[20] {
  %p0.1 = f32[10]{0} parameter(0)
  ROOT %custom-call = f32[20]{0} custom-call(f32[10]{0} %p0.1), custom_call_target="foo"
}, execution_thread="parallel_thread"

ENTRY %Entry (p0: f32[10]) -> f32[20] {
  %p0 = f32[10]{0} parameter(0)
  %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) async-start(f32[10]{0} %p0), async_execution_thread="parallel_thread", calls=%AsyncOp
  %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) async-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_execution_thread="parallel_thread", calls=%AsyncOp
  ROOT %async-done = f32[20]{0} async-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_execution_thread="parallel_thread", calls=%AsyncOp
}

)";
  auto options = HloPrintOptions().set_syntax_sugar_async_ops(false);
  EXPECT_EQ(module->ToString(options), expected_without_syntax_sugar);
}
1772 
// Tests canonical stringification (HloPrintOptions().Canonical()) of a
// simple op (dot) and of a fusion whose called computations have been
// assigned a named execution thread.
TEST_F(HloInstructionTest, CanonicalStringificationFusion) {
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto options = HloPrintOptions().Canonical();

  // The canonical form carries no instruction or operand names.
  EXPECT_EQ(dot->ToString(options),
            "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), "
            "lhs_contracting_dims={1}, rhs_contracting_dims={0}");

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  constexpr char kParallelThreadName[] = "parallel_thread";
  computation->SetExecutionThread(kParallelThreadName);
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {dot, transpose}, HloInstruction::FusionKind::kLoop);
  fusion->set_called_computations_execution_thread(
      kParallelThreadName,
      /*skip_async_execution_thread_overwrite*/ false);

  // The fused computation is printed inline, followed by its thread name.
  const std::string expected_fusion =
      R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
  ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}, execution_thread="parallel_thread")";
  EXPECT_EQ(fusion->ToString(options), expected_fusion);
}
1819 
// Tests canonical stringification of a while instruction whose condition
// and body are the same computation, which itself contains a fusion. Both
// called computations are expected to be printed inline and in full.
TEST_F(HloInstructionTest, CanonicalStringificationWhile) {
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  computation->CreateFusionInstruction({dot, transpose},
                                       HloInstruction::FusionKind::kLoop);

  HloInstruction* loop = builder.AddInstruction(
      HloInstruction::CreateWhile(sout, computation, computation, x));

  auto options = HloPrintOptions().Canonical();
  const std::string expected_loop =
      R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
  {
    tmp_0 = f32[5,10]{1,0} parameter(0)
    tmp_1 = f32[20,10]{1,0} parameter(1)
    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  }
}, body=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
  {
    tmp_0 = f32[5,10]{1,0} parameter(0)
    tmp_1 = f32[20,10]{1,0} parameter(1)
    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  }
})";
  EXPECT_EQ(loop->ToString(options), expected_loop);
}
1875 
// Tests canonical stringification of a conditional whose true and false
// computations are the same computation, which itself contains a fusion.
// Both branches are expected to be printed inline and in full.
TEST_F(HloInstructionTest, CanonicalStringificationConditional) {
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  computation->CreateFusionInstruction({dot, transpose},
                                       HloInstruction::FusionKind::kLoop);

  // NOTE(review): this while instruction is not referenced by any
  // expectation below; it is kept to match the original test setup. Confirm
  // whether it is still needed.
  builder.AddInstruction(
      HloInstruction::CreateWhile(sout, computation, computation, x));

  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          sout, pred, x, computation, x, computation));
  auto options = HloPrintOptions().Canonical();
  const std::string expected_conditional =
      R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
  {
    tmp_0 = f32[5,10]{1,0} parameter(0)
    tmp_1 = f32[20,10]{1,0} parameter(1)
    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  }
}, false_computation=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
  {
    tmp_0 = f32[5,10]{1,0} parameter(0)
    tmp_1 = f32[20,10]{1,0} parameter(1)
    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  }
})";
  EXPECT_EQ(conditional->ToString(options), expected_conditional);
}
1936 
// Check that deep clones really deep clones every instruction and
// computations, without leaving dangling pointers to the old module. The
// module below nests computations through while, call and reduce-window so
// that several levels of called computations are covered.
TEST_F(HloInstructionTest, CheckDeepClone) {
  const char* const hlo_string = R"(
HloModule Module

addy (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT zadd = s32[] add(lhs, rhs)
}

calla (x: s32[]) -> s32[] {
  x = s32[] parameter(0)
  reduce = s32[] reduce-window(x, x), to_apply=addy
  ROOT xadd = s32[] add(x, reduce)
}

body (bparam: s32[]) -> s32[] {
  constant = s32[] constant(1)
  bparam = s32[] parameter(0)
  v = s32[] call(bparam), to_apply=calla
  ROOT add = s32[] add(constant, bparam)
}

condition (cparam: s32[]) -> pred[] {
  xconstant = s32[] constant(5)
  cparam = s32[] parameter(0)
  ROOT greater-than = pred[] compare(xconstant, cparam), direction=GT
}

ENTRY entry (param: s32[]) -> s32[] {
  eparam = s32[] parameter(0)
  ROOT while = s32[] while(eparam), condition=condition, body=body
 }
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  std::unique_ptr<HloModule> clone = module->Clone();
  // Every computation, and every instruction's owning computation, must
  // point at the cloned module rather than the original one.
  for (HloComputation* computation : clone->computations()) {
    EXPECT_EQ(computation->parent(), clone.get());
    for (HloInstruction* instruction : computation->instructions()) {
      EXPECT_EQ(instruction->parent()->parent(), clone.get());
    }
  }
}
1983 
// Two otherwise identical adds must stop comparing Identical() once one of
// them is given a backend config string.
TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) {
  const Shape shape = ShapeUtil::MakeShape(F32, {42});
  HloComputation::Builder builder("test");
  HloInstruction* p =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p"));

  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p, p));

  EXPECT_TRUE(add1->Identical(*add2));
  add1->set_raw_backend_config_string("abc");
  EXPECT_FALSE(add1->Identical(*add2));
}
1999 
// Verifies that Identical() compares the window attached to a custom call.
TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallWindow) {
  auto custom_call = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}), /*operands=*/{},
      /*custom_call_target=*/"foo");
  auto custom_call_copy = custom_call->Clone();
  EXPECT_TRUE(custom_call->Identical(*custom_call_copy));

  // Only the original receives a window; the clone keeps its earlier state.
  const Window window = window_util::MakeWindow({1, 2, 3});
  custom_call->set_window(window);
  EXPECT_FALSE(custom_call->Identical(*custom_call_copy));
}
2011 
// Verifies that Identical() compares the convolution dimension numbers
// attached to a custom call.
TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallDnums) {
  auto custom_call = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}), /*operands=*/{},
      /*custom_call_target=*/"foo");
  auto custom_call_copy = custom_call->Clone();
  EXPECT_TRUE(custom_call->Identical(*custom_call_copy));

  // Only the original receives dimension numbers after the clone was taken.
  ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_output_batch_dimension(42);
  custom_call->set_convolution_dimension_numbers(dimension_numbers);
  EXPECT_FALSE(custom_call->Identical(*custom_call_copy));
}
2024 
// Verifies that Identical() compares the has-side-effect flag of a custom
// call.
TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallHasSideEffect) {
  auto custom_call = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}), /*operands=*/{},
      /*custom_call_target=*/"foo");
  auto custom_call_copy = custom_call->Clone();
  EXPECT_TRUE(custom_call->Identical(*custom_call_copy));

  // Flip the flag on the original only; the clone was taken beforehand.
  Cast<HloCustomCallInstruction>(custom_call.get())
      ->set_custom_call_has_side_effect(true);
  EXPECT_FALSE(custom_call->Identical(*custom_call_copy));
}
2036 
// Verifies that Clone() carries over the window set on a custom call.
TEST_F(HloInstructionTest, CloneWindowOnCustomCall) {
  auto custom_call = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}), /*operands=*/{},
      /*custom_call_target=*/"foo");
  const Window expected_window = window_util::MakeWindow({1, 2, 3});
  custom_call->set_window(expected_window);

  auto cloned = custom_call->Clone();
  // On mismatch, print the window the clone actually holds.
  EXPECT_TRUE(protobuf_util::ProtobufEquals(cloned->window(), expected_window))
      << cloned->window().DebugString();
}
2047 
// Verifies that Clone() carries over the convolution dimension numbers set on
// a custom call.
TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) {
  auto custom_call = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}), /*operands=*/{},
      /*custom_call_target=*/"foo");
  ConvolutionDimensionNumbers expected_dnums;
  expected_dnums.set_output_batch_dimension(42);
  custom_call->set_convolution_dimension_numbers(expected_dnums);

  auto cloned = custom_call->Clone();
  // On mismatch, print the dimension numbers the clone actually holds.
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      cloned->convolution_dimension_numbers(), expected_dnums))
      << cloned->convolution_dimension_numbers().DebugString();
}
2060 
// Verifies that the has-side-effect flag of a custom call starts out false,
// can be set, and is carried over by Clone().
TEST_F(HloInstructionTest, CloneHasSideEffectOnCustomCall) {
  auto instruction = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}), /*operands=*/{},
      /*custom_call_target=*/"foo");
  auto* custom_call = Cast<HloCustomCallInstruction>(instruction.get());

  // A freshly created custom call reports no side effect until told otherwise.
  EXPECT_FALSE(custom_call->custom_call_has_side_effect());
  custom_call->set_custom_call_has_side_effect(true);
  EXPECT_TRUE(custom_call->custom_call_has_side_effect());

  auto cloned = instruction->Clone();
  auto* cloned_custom_call = Cast<HloCustomCallInstruction>(cloned.get());
  EXPECT_TRUE(cloned_custom_call->custom_call_has_side_effect());
}
2073 
// Verifies that the generic HloInstruction::HasSideEffect() query reflects the
// custom-call-specific has-side-effect flag.
TEST_F(HloInstructionTest, CustomCallHasSideEffect) {
  auto instruction = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}), /*operands=*/{},
      /*custom_call_target=*/"foo");
  auto* custom_call = Cast<HloCustomCallInstruction>(instruction.get());

  EXPECT_FALSE(instruction->HasSideEffect());
  custom_call->set_custom_call_has_side_effect(true);
  EXPECT_TRUE(instruction->HasSideEffect());
}
2083 
// Verifies that cloning a convolution keeps its per-operand precision
// settings.
TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) {
  constexpr char kHloString[] = R"(
  HloModule test_module
  ENTRY test {
    arg0 = f32[1,2,1] parameter(0)
    arg1 = f32[1,1,1] parameter(1)
    ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1},
      dim_labels=b0f_0io->b0f, operand_precision={high,default}
  })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloString));

  // The convolution is the root of the entry computation.
  auto cloned_conv = module->entry_computation()->root_instruction()->Clone();
  EXPECT_THAT(cloned_conv->precision_config().operand_precision(),
              ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::DEFAULT));
}
2102 
// Verifies that cloning an instruction keeps its outer_dimension_partitions
// attribute.
TEST_F(HloInstructionTest, PreserveOuterDimensionPartitionsOnClone) {
  constexpr char kHloString[] = R"(
  HloModule test_module
  ENTRY test {
    ROOT iota = f32[100] iota(), iota_dimension=0, outer_dimension_partitions={0, 50}
  })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloString));

  // The iota is the root of the entry computation.
  auto cloned_iota = module->entry_computation()->root_instruction()->Clone();
  EXPECT_THAT(cloned_iota->outer_dimension_partitions(), ElementsAre(0, 50));
}
2117 
// A fusion whose body consumes one and the same reshape of its parameter
// several times.  Since every use goes through that single reshape, the use
// counts as UseKind::kUsePermutingElements, which the public API reports as
// "does not reuse this operand".
TEST_F(HloInstructionTest, ReuseReshapeOfFusionParameter) {
  constexpr char kHloString[] = R"(
  HloModule test_module
  f {
    p = f32[3,2] parameter(0)
    r = f32[2,3] reshape(p)
    x = f32[2,3] multiply(r, r)
    y = f32[2,3] add(r, r)
    ROOT sum = f32[2,3] add(x, y)
  }
  ENTRY test {
    p = f32[3,2] parameter(0)
    ROOT fusion = f32[2,3] fusion(p), calls=f, kind=kLoop
  })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloString));

  const HloInstruction* fusion =
      module->entry_computation()->root_instruction();
  EXPECT_FALSE(fusion->ReusesOperandElements(0));
}
2140 
// A fusion whose body consumes two *different* reshapes of its parameter.
// Because the reshapes are not the same, the use can no longer be treated as
// a single permutation of the operand's elements (the
// UseKind::kUsePermutingElements case, which is reported as "does not reuse
// this operand").  Instead it counts as UseKind::kReuse, which is exposed
// publicly as "does reuse this operand" -- hence the EXPECT_TRUE below.
TEST_F(HloInstructionTest, ReuseMultipleReshapesOfFusionParameter) {
  constexpr char kHloString[] = R"(
  HloModule test_module
  f {
    p = f32[3,2] parameter(0)
    r1 = f32[2,3] reshape(p)
    r2 = f32[6,1] reshape(p)
    ROOT result = (f32[2,3], f32[6,1]) tuple(r1, r2)
  }
  ENTRY test {
    p = f32[3,2] parameter(0)
    ROOT fusion = (f32[2,3], f32[6,1]) fusion(p), calls=f, kind=kLoop
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloString));
  // The fusion instruction is the root of the entry computation.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_TRUE(root->ReusesOperandElements(0));
}
2163 
// Verifies that a bitcast is not reported as reusing the elements of its
// operand.
TEST_F(HloInstructionTest, BitcastDoesNotReuseElements) {
  constexpr char kHloString[] = R"(
  HloModule test_module
  ENTRY test {
    p = f32[3,2]{0,1} parameter(0)
    ROOT bitcast = f32[6] bitcast(p)
  })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloString));

  const HloInstruction* bitcast =
      module->entry_computation()->root_instruction();
  EXPECT_FALSE(bitcast->ReusesOperandElements(0));
}
2176 
// Verifies that a gather is not reported as reusing the elements of either of
// its operands (the gathered input and the index tensor).
TEST_F(HloInstructionTest, GatherDoesNotReuseElements) {
  constexpr char kHloString[] = R"(
  HloModule test_module

  ENTRY test {
    input = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
    idx = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
    ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0}
      gather(input, idx), offset_dims={4,5,6,7,8}, collapsed_slice_dims={},
      start_index_map={0,1,2,3,4}, index_vector_dim=4,
      slice_sizes={30,29,28,27,26}
  })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloString));

  const HloInstruction* gather =
      module->entry_computation()->root_instruction();
  EXPECT_FALSE(gather->ReusesOperandElements(0));  // `input`
  EXPECT_FALSE(gather->ReusesOperandElements(1));  // `idx`
}
2195 
// Verifies that a backend config holding an infinity and a NaN can be stored
// on an instruction and read back with those non-finite values intact.
TEST_F(HloInstructionTest, BackendConfigCanContainNonFiniteFloats) {
  HloComputation::Builder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "p0"));

  // Dot of the parameter with itself, contracting lhs dim 1 with rhs dim 0.
  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(1);
  contraction.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      matrix_shape, param, param, contraction, DefaultPrecisionConfig(2)));

  gpu::GemmBackendConfig written_config;
  written_config.set_alpha_real(std::numeric_limits<double>::infinity());
  written_config.set_alpha_imag(std::numeric_limits<double>::quiet_NaN());
  TF_ASSERT_OK(dot->set_backend_config(written_config));

  TF_ASSERT_OK_AND_ASSIGN(auto read_config,
                          dot->backend_config<gpu::GemmBackendConfig>());
  // +infinity is the only double strictly greater than max().
  EXPECT_GT(read_config.alpha_real(), std::numeric_limits<double>::max());
  // NaN is the only value that compares unequal to itself.
  EXPECT_NE(read_config.alpha_imag(), read_config.alpha_imag());
}
2216 
2217 }  // namespace
2218 }  // namespace xla
2219