1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
17
18 #include <optional>
19 #include <set>
20 #include <string>
21 #include <utility>
22 #include <vector>
23
24 #include "absl/container/flat_hash_map.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/protobuf_util.h"
27 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
28 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
29 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
30 #include "tensorflow/compiler/xla/service/hlo_computation.h"
31 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/test.h"
34 #include "tensorflow/compiler/xla/test_helpers.h"
35 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/compiler/xla/window_util.h"
38 #include "tensorflow/core/lib/core/status_test_util.h"
39
40 namespace xla {
41 namespace {
42
43 using ::testing::ElementsAre;
44 using ::testing::UnorderedElementsAre;
45
46 class HloInstructionTest : public HloTestBase {
47 protected:
48 Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
49 };
50
51 // Simple visitor that collects the number of users and operands for certain HLO
52 // nodes. It also verifies some of the DFS visiting invariants (operands visited
53 // before their users, nodes not visited twice, etc.)
54 class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
55 public:
DefaultAction(HloInstruction * hlo_instruction)56 Status DefaultAction(HloInstruction* hlo_instruction) override {
57 return Unimplemented("not implemented %s",
58 HloOpcodeString(hlo_instruction->opcode()));
59 }
60
HandleParameter(HloInstruction * parameter)61 Status HandleParameter(HloInstruction* parameter) override {
62 EXPECT_FALSE(count_.contains(parameter));
63 count_[parameter] = GetCountsForNode(parameter);
64 return OkStatus();
65 }
66
HandleConstant(HloInstruction * constant)67 Status HandleConstant(HloInstruction* constant) override {
68 EXPECT_FALSE(count_.contains(constant));
69 count_[constant] = GetCountsForNode(constant);
70 return OkStatus();
71 }
72
HandleAdd(HloInstruction * add)73 Status HandleAdd(HloInstruction* add) override {
74 auto lhs = add->operand(0);
75 auto rhs = add->operand(1);
76 EXPECT_FALSE(count_.contains(add));
77 EXPECT_TRUE(count_.contains(lhs));
78 EXPECT_TRUE(count_.contains(rhs));
79 count_[add] = GetCountsForNode(add);
80 return OkStatus();
81 }
82
HandleNegate(HloInstruction * negate)83 Status HandleNegate(HloInstruction* negate) override {
84 auto operand = negate->operand(0);
85 EXPECT_FALSE(count_.contains(negate));
86 EXPECT_TRUE(count_.contains(operand));
87 count_[negate] = GetCountsForNode(negate);
88 return OkStatus();
89 }
90
HandleMap(HloInstruction * map)91 Status HandleMap(HloInstruction* map) override {
92 EXPECT_FALSE(count_.contains(map));
93 for (HloInstruction* arg : map->operands()) {
94 EXPECT_TRUE(count_.contains(arg));
95 }
96 count_[map] = GetCountsForNode(map);
97 return OkStatus();
98 }
99
HandleReduce(HloInstruction * reduce)100 Status HandleReduce(HloInstruction* reduce) override {
101 auto arg = reduce->operand(0);
102 auto init_value = reduce->operand(1);
103 EXPECT_FALSE(count_.contains(reduce));
104 EXPECT_TRUE(count_.contains(arg));
105 EXPECT_TRUE(count_.contains(init_value));
106 count_[reduce] = GetCountsForNode(reduce);
107 return OkStatus();
108 }
109
NumOperands(const HloInstruction * node)110 int64_t NumOperands(const HloInstruction* node) {
111 auto count_iterator = count_.find(node);
112 EXPECT_NE(count_.end(), count_iterator);
113 return count_iterator->second.operand_count;
114 }
115
NumUsers(const HloInstruction * node)116 int64_t NumUsers(const HloInstruction* node) {
117 auto count_iterator = count_.find(node);
118 EXPECT_NE(count_.end(), count_iterator);
119 return count_iterator->second.user_count;
120 }
121
122 private:
123 struct NumOpsAndUsers {
124 int64_t operand_count;
125 int64_t user_count;
126 };
127
128 // Helper function to count operands and users for the given HLO.
GetCountsForNode(const HloInstruction * node)129 NumOpsAndUsers GetCountsForNode(const HloInstruction* node) {
130 NumOpsAndUsers counts{node->operand_count(), node->user_count()};
131 return counts;
132 }
133
134 // Counters for HLOs. Maps HLO to a NumOpsAndUsers.
135 absl::flat_hash_map<const HloInstruction*, NumOpsAndUsers> count_;
136 };
137
TEST_F(HloInstructionTest, BasicProperties) {
  // A standalone scalar F32 parameter (never added to a computation).
  auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo");

  EXPECT_EQ(HloOpcode::kParameter, parameter->opcode());
  EXPECT_EQ(1, parameter->parameter_number());
  EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32));
  EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32));
  // A parameter takes no operands. Compare against 0 explicitly instead of
  // relying on an int64_t -> bool conversion inside EXPECT_FALSE.
  EXPECT_EQ(0, parameter->operand_count());
}
146
TEST_F(HloInstructionTest, UserWithTwoOperands) {
  // Graph under test: two parameters feeding a single add.
  //
  //   [Param foo]-----> |-----|
  //                     | Add |
  //   [Param bar]-----> |-----|
  HloComputation::Builder b(TestName());
  HloInstruction* const lhs =
      b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* const rhs =
      b.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* const sum = b.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(b.Build());

  // Structural checks read directly off the instructions.
  EXPECT_THAT(sum->operands(), UnorderedElementsAre(lhs, rhs));
  for (const HloInstruction* param : {lhs, rhs}) {
    EXPECT_THAT(param->users(), UnorderedElementsAre(sum));
  }

  // The same facts as observed by a DFS visitor rooted at the add.
  OpAndUserCollectingVisitor visitor;
  ASSERT_IS_OK(sum->Accept(&visitor));

  EXPECT_EQ(2, visitor.NumOperands(sum));
  EXPECT_EQ(0, visitor.NumUsers(sum));
  EXPECT_EQ(1, visitor.NumUsers(lhs));
  EXPECT_EQ(1, visitor.NumUsers(rhs));
}
173
TEST_F(HloInstructionTest, MultipleUsers) {
  // 'foo' is consumed by three instructions; 'bar' only by the add.
  //
  //             [Param foo]
  //          /       |       \         [Param bar]
  //         /        |        \             |
  //        V         V         V            V
  //     -------   -------   ----------------------
  //     | exp |   | exp |   |        add         |
  //     -------   -------   ----------------------
  HloComputation::Builder builder(TestName());
  HloInstruction* const foo = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* const bar = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* const first_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
  HloInstruction* const second_exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
  HloInstruction* const sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(3, foo->user_count());
  EXPECT_EQ(1, bar->user_count());
  // None of the three consumers is itself used by anything.
  for (const HloInstruction* sink : {first_exp, second_exp, sum}) {
    EXPECT_EQ(0, sink->user_count());
  }

  OpAndUserCollectingVisitor visitor;
  ASSERT_IS_OK(sum->Accept(&visitor));

  EXPECT_EQ(2, visitor.NumOperands(sum));
  // The visitor records every user of foo in the computation, including the
  // two exps that are not on the path walked from the add.
  EXPECT_EQ(3, visitor.NumUsers(foo));
}
210
TEST_F(HloInstructionTest, RepeatedUser) {
  // A single 'add' node uses the same HLO as both of its operands. Make sure
  // it is not counted as two distinct users.
  //
  //     [Param foo]
  //       |     |
  //       |     |
  //       V     V
  //       -------
  //       | add |
  //       -------
  HloComputation::Builder builder(TestName());
  HloInstruction* const param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* const doubled = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, param));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // The add appears only once among the parameter's users...
  EXPECT_EQ(1, param->user_count());
  // ...but operands are positional, so the add still has two of them.
  EXPECT_EQ(2, doubled->operand_count());
}
236
TEST_F(HloInstructionTest, MultipleUsersAndOperands) {
  // One constant is shared by two adds, whose results are summed:
  //
  //   [param0]              [param1]
  //      |                      |
  //      |        [c0]          |
  //      |         |            |
  //      V         |            V
  //   -------      |         -------
  //   | add | <----^----> | add |
  //   -------                -------
  //       |                    |
  //        \      -------     /
  //         ----->| add |<----
  //               -------
  HloComputation::Builder builder(TestName());
  HloInstruction* const p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  HloInstruction* const p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "param1"));
  HloInstruction* const shared_const = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* const left_sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, p0, shared_const));
  HloInstruction* const right_sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, shared_const, p1));
  HloInstruction* const total = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, left_sum, right_sum));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  OpAndUserCollectingVisitor visitor;
  ASSERT_IS_OK(total->Accept(&visitor));

  // The constant is reached through both branches but has two users total.
  EXPECT_EQ(2, visitor.NumUsers(shared_const));
  for (const HloInstruction* binary : {left_sum, right_sum, total}) {
    EXPECT_EQ(2, visitor.NumOperands(binary));
  }
}
274
TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) {
  // Like MultipleUsersAndOperands, but the shared value is neg(c0) and the
  // final sum is negated again:
  //
  //   [param0]     [c0]      [param1]
  //      |          |           |
  //      |          V           |
  //      |       -------        |
  //      |       | neg |        |
  //      |       -------        |
  //      V          |           V
  //   -------       |        -------
  //   | add | <-----^-----> | add |
  //   -------                -------
  //       |                    |
  //        \      -------     /
  //         ----->| add |<----
  //               -------
  //                  |
  //                  V
  //               -------
  //               | neg |
  //               -------
  HloComputation::Builder builder(TestName());
  HloInstruction* const p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "param0"));
  HloInstruction* const p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "param1"));
  HloInstruction* const c = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* const neg_c = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, c));
  HloInstruction* const left_sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, p0, neg_c));
  HloInstruction* const right_sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, neg_c, p1));
  HloInstruction* const total = builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, left_sum, right_sum));
  HloInstruction* const neg_total = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, total));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  OpAndUserCollectingVisitor visitor;
  ASSERT_IS_OK(neg_total->Accept(&visitor));

  // User counts along the shared branch.
  EXPECT_EQ(1, visitor.NumUsers(c));
  EXPECT_EQ(2, visitor.NumUsers(neg_c));
  // Every add is binary.
  for (const HloInstruction* binary : {left_sum, right_sum, total}) {
    EXPECT_EQ(2, visitor.NumOperands(binary));
  }
  // The root negate has one operand and no users.
  EXPECT_EQ(1, visitor.NumOperands(neg_total));
  EXPECT_EQ(0, visitor.NumUsers(neg_total));
}
326
TEST_F(HloInstructionTest, TrivialMap) {
  // The only operation in the entry computation is a map of x -> x + 1 over a
  // 100x10 parameter:
  //
  //   param0[100x10] ---> (map x+1)
  //
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  const Shape matrix_f32 = ShapeUtil::MakeShape(F32, {100, 10});
  auto module = CreateNewVerifiedModule();

  // Mapped function f(x) = x + 1.0, registered as an embedded computation.
  HloComputation* add_one = nullptr;
  {
    HloComputation::Builder mapper("f32+1");
    HloInstruction* const x = mapper.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_f32, "x"));
    HloInstruction* const one = mapper.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
    mapper.AddInstruction(
        HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, x, one));
    add_one = module->AddEmbeddedComputation(mapper.Build());
  }

  // Entry computation: a single parameter fed into the map.
  HloComputation::Builder builder(TestName());
  HloInstruction* const input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_f32, "p"));
  HloInstruction* const mapped = builder.AddInstruction(
      HloInstruction::CreateMap(matrix_f32, {input}, add_one));
  module->AddEntryComputation(builder.Build());

  OpAndUserCollectingVisitor visitor;
  ASSERT_IS_OK(mapped->Accept(&visitor));

  // Only the entry computation is walked; the mapped function is not visited.
  EXPECT_EQ(1, visitor.NumUsers(input));
  EXPECT_EQ(0, visitor.NumUsers(mapped));
  EXPECT_EQ(1, visitor.NumOperands(mapped));

  // TODO(dehnert): Add walking and counters for the wrapped computation.
}
364
TEST_F(HloInstructionTest, TrivialReduce) {
  // A single x+y reduction over dimension 1 of a 100x10 parameter:
  //
  //   param0[100x10] ---> (reduce x+y)
  //
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  const Shape vector_f32 = ShapeUtil::MakeShape(F32, {100});
  const Shape matrix_f32 = ShapeUtil::MakeShape(F32, {100, 10});

  // Reducer: (x, y) -> x + y.
  HloComputation::Builder reducer("f32+f32");
  HloInstruction* const x = reducer.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "x"));
  HloInstruction* const y = reducer.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "y"));
  reducer.AddInstruction(
      HloInstruction::CreateBinary(scalar_f32, HloOpcode::kAdd, x, y));
  auto module = CreateNewVerifiedModule();
  HloComputation* const sum_reducer =
      module->AddEmbeddedComputation(reducer.Build());

  // Entry computation: reduce(param, 0.0f). The additional 1.1f constant is
  // not an operand of the reduce.
  HloComputation::Builder builder(TestName());
  HloInstruction* const input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_f32, "p"));
  HloInstruction* const zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* const reduce = builder.AddInstruction(
      HloInstruction::CreateReduce(vector_f32, input, zero,
                                   /*dimensions_to_reduce=*/{1}, sum_reducer));
  module->AddEntryComputation(builder.Build());

  OpAndUserCollectingVisitor visitor;
  ASSERT_IS_OK(reduce->Accept(&visitor));

  // Only the entry computation is walked; the reducer itself is not visited.
  EXPECT_EQ(2, visitor.NumOperands(reduce));
  EXPECT_EQ(0, visitor.NumUsers(reduce));
  EXPECT_EQ(1, visitor.NumUsers(input));
  EXPECT_EQ(1, visitor.NumUsers(zero));
}
407
TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) {
  // Construct a graph of a few binary ops using two different parameters.
  // Replace one of the parameters with the other parameter in one of the
  // instructions, and check that only that instruction is rewired.
  HloComputation::Builder builder(TestName());
  auto foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  auto bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));
  auto add_foobar = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
  auto add_foofoo = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, foo));
  builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
                                                      add_foobar, add_foofoo));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(2, foo->user_count());
  EXPECT_EQ(1, bar->user_count());

  // Replace the use of foo in add_foofoo with bar.
  ASSERT_IS_OK(foo->ReplaceUseWith(add_foofoo, bar));

  EXPECT_EQ(1, foo->user_count());
  EXPECT_EQ(2, bar->user_count());

  // foo keeps only its untouched user, whose operands are unchanged.
  EXPECT_THAT(foo->users(), UnorderedElementsAre(add_foobar));
  EXPECT_THAT(add_foobar->operands(), ElementsAre(foo, bar));

  // bar gains add_foofoo, in which both uses of foo were replaced.
  EXPECT_THAT(bar->users(), UnorderedElementsAre(add_foobar, add_foofoo));
  EXPECT_THAT(add_foofoo->operands(), ElementsAre(bar, bar));
}
442
TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) {
  // A four-element tuple references 'foo' twice (first and last slot). After
  // redirecting the tuple's use of 'foo' to 'bar', every foo slot in the tuple
  // must point at bar, while the unrelated add keeps using foo.
  HloComputation::Builder builder(TestName());
  HloInstruction* const foo = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* const bar = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* const baz = builder.AddInstruction(
      HloInstruction::CreateParameter(2, r0f32_, "baz"));

  HloInstruction* const tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({foo, bar, baz, foo}));
  HloInstruction* const foo_plus_bar = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo, bar));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Before: foo is used by both the tuple and the add.
  EXPECT_EQ(2, foo->user_count());
  EXPECT_THAT(foo->users(), UnorderedElementsAre(tuple, foo_plus_bar));

  ASSERT_IS_OK(foo->ReplaceUseWith(tuple, bar));

  // After: only the add still uses foo...
  EXPECT_THAT(foo->users(), UnorderedElementsAre(foo_plus_bar));
  // ...and both of the tuple's foo slots now refer to bar.
  EXPECT_THAT(tuple->operands(), ElementsAre(bar, bar, baz, bar));
}
472
TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) {
  // Construct a couple of unary instructions which use a parameter. Replace
  // the use of the parameter in one of the unary ops with the other parameter.
  HloComputation::Builder builder(TestName());
  auto foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  auto bar =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "bar"));

  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_EQ(2, foo->user_count());
  EXPECT_THAT(foo->users(), UnorderedElementsAre(exp, log));
  EXPECT_EQ(0, bar->user_count());

  // Replace the use of foo in exp with bar.
  ASSERT_IS_OK(foo->ReplaceUseWith(exp, bar));

  // The use of foo in log should not have been affected.
  EXPECT_EQ(1, foo->user_count());
  EXPECT_THAT(foo->users(), UnorderedElementsAre(log));
  EXPECT_THAT(log->operands(), ElementsAre(foo));

  // Bar should now be used (only) in exp, and exp's only operand is bar.
  // Container matchers are used instead of dereferencing begin(), which would
  // be undefined behavior if a preceding non-fatal EXPECT had failed and the
  // container were empty. They also avoid a signed/unsigned size comparison.
  EXPECT_EQ(1, bar->user_count());
  EXPECT_THAT(bar->users(), ElementsAre(exp));
  EXPECT_THAT(exp->operands(), ElementsAre(bar));
}
507
TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) {
  // 'foo' feeds two adds (one of them twice) and 'bar' feeds one. Redirecting
  // every use of foo to bar must leave foo unused and give bar both adds.
  HloComputation::Builder builder(TestName());
  HloInstruction* const old_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* const new_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "bar"));
  HloInstruction* const mixed_sum =
      builder.AddInstruction(HloInstruction::CreateBinary(
          r0f32_, HloOpcode::kAdd, old_param, new_param));
  HloInstruction* const self_sum =
      builder.AddInstruction(HloInstruction::CreateBinary(
          r0f32_, HloOpcode::kAdd, old_param, old_param));
  builder.AddInstruction(HloInstruction::CreateBinary(
      r0f32_, HloOpcode::kAdd, mixed_sum, self_sum));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Before the rewrite.
  EXPECT_EQ(2, old_param->user_count());
  EXPECT_EQ(1, new_param->user_count());

  ASSERT_IS_OK(old_param->ReplaceAllUsesWith(new_param));

  // After the rewrite, every former user of foo now uses bar.
  EXPECT_EQ(0, old_param->user_count());
  EXPECT_EQ(2, new_param->user_count());
  EXPECT_THAT(new_param->users(), UnorderedElementsAre(mixed_sum, self_sum));
}
537
TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) {
  // 'foo' is used by a binary (add), a unary (exp) and a variadic (tuple) op.
  // After ReplaceAllUsesWith(bar), all three must use bar instead.
  HloComputation::Builder builder(TestName());
  HloInstruction* const old_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r0f32_, "foo"));
  HloInstruction* const new_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, r0f32_, "bar"));

  HloInstruction* const binary_user =
      builder.AddInstruction(HloInstruction::CreateBinary(
          r0f32_, HloOpcode::kAdd, old_param, new_param));
  HloInstruction* const unary_user = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, old_param));
  HloInstruction* const variadic_user = builder.AddInstruction(
      HloInstruction::CreateTuple({old_param, new_param}));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Before the rewrite.
  EXPECT_EQ(3, old_param->user_count());
  EXPECT_EQ(2, new_param->user_count());

  ASSERT_IS_OK(old_param->ReplaceAllUsesWith(new_param));

  // After the rewrite, bar has absorbed the exp as a new user.
  EXPECT_EQ(0, old_param->user_count());
  EXPECT_EQ(3, new_param->user_count());
  EXPECT_THAT(new_param->users(),
              UnorderedElementsAre(binary_user, unary_user, variadic_user));
}
567
568 // Simple visitor that collects and post-processes each node in the graph.
569 class NodeCollectorAndPostProcessor : public DfsHloVisitorWithDefault {
570 public:
NodeCollectorAndPostProcessor()571 NodeCollectorAndPostProcessor() {}
572
Postprocess(HloInstruction * hlo)573 Status Postprocess(HloInstruction* hlo) override {
574 post_processed_nodes_.push_back(hlo);
575 return OkStatus();
576 }
577
DefaultAction(HloInstruction * hlo_instruction)578 Status DefaultAction(HloInstruction* hlo_instruction) override {
579 visited_nodes_.push_back(hlo_instruction);
580 return OkStatus();
581 }
582
visited_nodes()583 const std::vector<const HloInstruction*>& visited_nodes() {
584 return visited_nodes_;
585 }
586
post_processed_nodes()587 const std::vector<const HloInstruction*>& post_processed_nodes() {
588 return post_processed_nodes_;
589 }
590
591 private:
592 std::vector<const HloInstruction*> visited_nodes_;
593 std::vector<const HloInstruction*> post_processed_nodes_;
594 };
595
596 // Returns true if "vec" contains distinct nodes.
Distinct(const std::vector<const HloInstruction * > & vec)597 bool Distinct(const std::vector<const HloInstruction*>& vec) {
598 std::set<const HloInstruction*> distinct_nodes(vec.begin(), vec.end());
599 return distinct_nodes.size() == vec.size();
600 }
601
TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) {
  // Verifies all the nodes are visited and post-processed in the same order,
  // and that each node is visited exactly once.
  //
  //   foo ---> exp ---> add
  //    |                 ^
  //    +-----> log ------+
  //
  // (The diagram avoids a trailing backslash, which would splice the next
  // comment line into this one and trigger -Wcomment.)
  HloComputation::Builder builder(TestName());
  auto foo =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "foo"));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo));
  auto log = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp, log));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  NodeCollectorAndPostProcessor visitor;
  ASSERT_IS_OK(add->Accept(&visitor));
  // Verifies every node reachable from the root is actually visited. This is
  // fatal so that back() below is never called on an empty vector.
  ASSERT_THAT(visitor.visited_nodes(),
              UnorderedElementsAre(foo, exp, log, add));
  // Operands are visited before their users, so the root comes last.
  EXPECT_EQ(add, visitor.visited_nodes().back());
  // Verifies all the nodes are visited and post-processed in the same order.
  EXPECT_EQ(visitor.visited_nodes(), visitor.post_processed_nodes());
  // Verifies each node is visited exactly once.
  EXPECT_TRUE(Distinct(visitor.visited_nodes()));
}
628
TEST_F(HloInstructionTest, SingletonFusionOp) {
  // Fusing a lone exp(constant) should produce a fusion whose only operand is
  // the constant, and the constant's only user should become the fusion.
  HloComputation::Builder builder(TestName());
  HloInstruction* const c = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* const unary = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, c));
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  HloInstruction* const fused = entry->CreateFusionInstruction(
      {unary}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fused->operands(), ElementsAre(c));
  EXPECT_THAT(c->users(), ElementsAre(fused));
}
644
TEST_F(HloInstructionTest, BinaryFusionOp) {
  // Fusing a lone add of two constants should make both constants operands of
  // the fusion, in order, each with the fusion as its sole user.
  HloComputation::Builder builder(TestName());
  HloInstruction* const lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* const rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
  HloInstruction* const sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  HloInstruction* const fused = entry->CreateFusionInstruction(
      {sum}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fused->operands(), ElementsAre(lhs, rhs));
  for (const HloInstruction* constant : {lhs, rhs}) {
    EXPECT_THAT(constant->users(), ElementsAre(fused));
  }
}
663
TEST_F(HloInstructionTest, ChainFusionOp) {
  // Fusing the chain exp(exp(exp(constant))) should leave the constant as the
  // fusion's only operand. Instructions are passed root-first.
  HloComputation::Builder builder(TestName());
  HloInstruction* const c = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* const inner = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, c));
  HloInstruction* const middle = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, inner));
  HloInstruction* const outer = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, middle));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  HloInstruction* const fused = entry->CreateFusionInstruction(
      {outer, middle, inner}, HloInstruction::FusionKind::kLoop);

  EXPECT_THAT(fused->operands(), ElementsAre(c));
  EXPECT_THAT(c->users(), ElementsAre(fused));
}
684
TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) {
  HloComputation::Builder builder(TestName());
  // Create a chain of fused unary ops.
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  auto exp1 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant));
  auto exp2 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1));
  OpMetadata metadata;
  metadata.set_op_name("tf_op");
  exp1->set_metadata(metadata);
  exp2->set_metadata(metadata);

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {exp2, exp1}, HloInstruction::FusionKind::kLoop);

  // Fusion: the fusion instruction and both fused ops carry the metadata.
  EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      metadata, fusion->fused_expression_root()->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(
      metadata, fusion->fused_expression_root()->operand(0)->metadata()));

  // Clone: cloning must leave the original's metadata intact, and the clone
  // itself must carry the same metadata. Previously `cloned` was created but
  // never inspected, so the "AndClone" half of this test checked nothing.
  auto cloned = fusion->CloneWithNewOperands(fusion->shape(), {});
  EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata()));
  EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, cloned->metadata()));
}
713
TEST_F(HloInstructionTest, BinaryCallOp) {
  // Outlining a lone add of two constants into a call should make both
  // constants operands of the call, in order, each with the call as sole user.
  HloComputation::Builder builder(TestName());
  HloInstruction* const lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* const rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
  HloInstruction* const sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  HloInstruction* const call_op = entry->CreateCallInstruction({sum});

  EXPECT_THAT(call_op->operands(), ElementsAre(lhs, rhs));
  for (const HloInstruction* constant : {lhs, rhs}) {
    EXPECT_THAT(constant->users(), ElementsAre(call_op));
  }
}
731
TEST_F(HloInstructionTest, ChainCallOp) {
  // Outlining the chain exp(exp(exp(constant))) into a call should leave the
  // constant as the call's only operand. Instructions are passed root-first.
  HloComputation::Builder builder(TestName());
  HloInstruction* const c = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* const inner = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, c));
  HloInstruction* const middle = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, inner));
  HloInstruction* const outer = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, middle));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  HloInstruction* const call_op =
      entry->CreateCallInstruction({outer, middle, inner});

  EXPECT_THAT(call_op->operands(), ElementsAre(c));
  EXPECT_THAT(c->users(), ElementsAre(call_op));
}
751
TEST_F(HloInstructionTest, MultiOutputCallOp) {
  // Two exp chains hang off one constant and meet at an add. The longer chain
  // is outlined into a call, and the lone exp is then appended to the called
  // computation as an additional output. Afterwards both add operands should
  // be get-tuple-elements of the call.
  HloComputation::Builder builder(TestName());
  HloInstruction* const c = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* const chain1 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, c));
  HloInstruction* const chain2 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, chain1));
  HloInstruction* const chain3 = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, chain2));
  HloInstruction* const side = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, c));
  HloInstruction* const sum = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, chain3, side));

  auto module = CreateNewVerifiedModule();
  HloComputation* const entry = module->AddEntryComputation(builder.Build());
  HloInstruction* const call_op =
      entry->CreateCallInstruction({chain3, chain2, chain1});
  call_op->AppendInstructionIntoCalledComputation(side, /*add_output=*/true);

  EXPECT_THAT(call_op->operands(), ElementsAre(c));
  for (int64_t i = 0; i < 2; ++i) {
    const HloInstruction* const operand = sum->operand(i);
    EXPECT_EQ(operand->opcode(), HloOpcode::kGetTupleElement);
    EXPECT_THAT(operand->operands(), ElementsAre(call_op));
  }
}
779
TEST_F(HloInstructionTest, AsyncOp) {
  // Wrap a single add in an async-start/async-done pair on the
  // "parallel_thread" execution thread, passing one extra U32 scalar shape,
  // and check the resulting pair and its wrapped computation.
  const Shape u32_scalar = ShapeUtil::MakeScalarShape(U32);
  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f)));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, lhs, rhs));
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(
      auto* async_done,
      entry->CreateAsyncInstructions(add, {u32_scalar}, "parallel_thread"));
  auto* async_start = async_done->operand(0);

  // The start's shape is a 3-tuple whose element 2 must match the extra U32
  // shape supplied above.
  const Shape& start_shape = async_start->shape();
  EXPECT_EQ(start_shape.tuple_shapes_size(), 3);
  EXPECT_EQ(async_start->async_execution_thread(), "parallel_thread");
  EXPECT_EQ(async_done->async_execution_thread(), "parallel_thread");
  EXPECT_TRUE(ShapeUtil::Equal(start_shape.tuple_shapes(2), u32_scalar));
  EXPECT_EQ(async_start->async_wrapped_computation()->execution_thread(),
            "parallel_thread");
  EXPECT_EQ(async_done->async_wrapped_computation()->execution_thread(),
            "parallel_thread");

  // The constants now feed only the async start, and the done op is the root.
  EXPECT_THAT(async_start->operands(), ElementsAre(lhs, rhs));
  EXPECT_THAT(lhs->users(), ElementsAre(async_start));
  EXPECT_THAT(rhs->users(), ElementsAre(async_start));
  EXPECT_EQ(entry->root_instruction(), async_done);
}
811
TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) {
  // Two outfeeds of the same value whose outfeed shapes differ only in layout;
  // each clone must keep the outfeed shape of the instruction it came from.
  const Shape row_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
  const Shape col_major = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});

  HloComputation::Builder builder(TestName());
  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  HloInstruction* token = builder.AddInstruction(HloInstruction::CreateToken());
  HloInstruction* row_major_outfeed = builder.AddInstruction(
      HloInstruction::CreateOutfeed(row_major, matrix, token, ""));
  HloInstruction* col_major_outfeed = builder.AddInstruction(
      HloInstruction::CreateOutfeed(col_major, matrix, token, ""));

  HloInstruction* col_major_clone =
      builder.AddInstruction(col_major_outfeed->Clone());
  HloInstruction* row_major_clone =
      builder.AddInstruction(row_major_outfeed->Clone());

  EXPECT_TRUE(ShapeUtil::Equal(col_major_clone->outfeed_shape(), col_major));
  EXPECT_TRUE(ShapeUtil::Equal(row_major_clone->outfeed_shape(), row_major));
}
833
TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) {
  // Give the two elements of a tuple different layouts and check that the
  // clone's shape is ShapeUtil::Equal to the original tuple's shape.
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));

  Shape* tuple_shape = tuple->mutable_shape();
  Shape* first = ShapeUtil::GetMutableSubshape(tuple_shape, {0});
  *first->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
  Shape* second = ShapeUtil::GetMutableSubshape(tuple_shape, {1});
  *second->mutable_layout() = LayoutUtil::MakeLayout({1, 0});

  auto cloned = tuple->Clone();
  EXPECT_TRUE(ShapeUtil::Equal(cloned->shape(), tuple->shape()));
}
850
TEST_F(HloInstructionTest, PreserveShardingThroughCompatibleClone) {
  // The original is a 2-tuple of rank-2 arrays. The clone is requested as a
  // 2-tuple of rank-2 arrays with different dimension sizes. Tuple tree
  // structure and leaf ranks match, so the sharding must carry over.
  const HloSharding sharding = HloSharding::AssignDevice(5);
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));
  tuple->set_sharding(sharding);

  const Shape leaf = ShapeUtil::MakeShape(F32, {3, 3});
  const Shape pair = ShapeUtil::MakeTupleShape({leaf, leaf});
  auto tuple_clone = tuple->CloneWithNewOperands(pair, {});
  EXPECT_EQ(tuple_clone->sharding(), sharding);
}
869
TEST_F(HloInstructionTest,
       DoNotPreserveShardingThroughTupleTreeIncompatibleClone) {
  // The original is a 2-tuple of f32[2,2]; the clone is requested as a 3-tuple
  // of f32[2,2]. The tuple tree structures differ, so the clone must have no
  // sharding.
  const HloSharding sharding = HloSharding::AssignDevice(5);
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({matrix, matrix}));
  tuple->set_sharding(sharding);

  const Shape leaf = ShapeUtil::MakeShape(F32, {2, 2});
  const Shape triple = ShapeUtil::MakeTupleShape({leaf, leaf, leaf});
  auto tuple_clone = tuple->CloneWithNewOperands(triple, {});
  EXPECT_FALSE(tuple_clone->has_sharding());
}
889
TEST_F(HloInstructionTest,
       DoNotPreserveShardingThroughLeafRankIncompatibleClone) {
  HloSharding sharding = HloSharding::AssignDevice(5);
  HloComputation::Builder builder(TestName());
  auto* constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({
          {1, 2},
          {3, 4},
      })));
  auto* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({constant, constant}));
  tuple->set_sharding(sharding);
  // Incompatible with original shape because the leaf ranks differ: the
  // original leaves are rank 2 (f32[2,2]) while the clone's leaves are rank 3
  // (f32[1,2,3]). The tuple tree structure (a 2-tuple) is the same.
  auto clone_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
  clone_shape = ShapeUtil::MakeTupleShape({clone_shape, clone_shape});
  auto tuple_clone = tuple->CloneWithNewOperands(clone_shape, {});
  EXPECT_FALSE(tuple_clone->has_sharding());
}
908
TEST_F(HloInstructionTest, FusionOpWithCalledComputations) {
  // Build a chain of map instructions, each calling a parameter-only scalar
  // computation, and fuse them into one fusion instruction step by step.
  // After every step, called_computations() of the fusion must contain only
  // its own fused computation, not the computations called by the fused maps.
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  auto module = CreateNewVerifiedModule();

  // Creates an embedded computation whose body is a single scalar parameter.
  auto make_map_computation = [&]() {
    auto builder = HloComputation::Builder("FusionMap");
    builder.AddInstruction(
        HloInstruction::CreateParameter(0, scalar_shape, "param"));
    return module->AddEmbeddedComputation(builder.Build());
  };

  HloComputation* computation_x = make_map_computation();
  HloComputation* computation_y = make_map_computation();

  HloComputation::Builder builder(TestName());
  auto constant = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  auto map_1_x = builder.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {constant}, computation_x));
  auto map_2_x = builder.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {map_1_x}, computation_x));
  auto map_3_y = builder.AddInstruction(
      HloInstruction::CreateMap(scalar_shape, {map_2_x}, computation_y));
  auto* computation = module->AddEntryComputation(builder.Build());

  auto* fusion = computation->CreateFusionInstruction(
      {map_3_y}, HloInstruction::FusionKind::kLoop);
  auto* fused_computation = fusion->fused_instructions_computation();
  EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));

  fusion->FuseInstruction(map_2_x);
  EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));

  fusion->FuseInstruction(map_1_x);
  EXPECT_THAT(fusion->called_computations(), ElementsAre(fused_computation));
}
946
TEST_F(HloInstructionTest, ComplexFusionOp) {
  HloComputation::Builder builder(TestName());
  // Fuse all instructions in the following expression:
  //
  //   add = Add(C1, C2)
  //   clamp = Clamp(C2, add, add)
  //   exp = Exp(add)
  //   mul = Mul(exp, C3)
  //   sub = Sub(mul, clamp)
  //   tuple = Tuple({sub, sub, mul, C1})
  //
  // Notable complexities are repeated operands in the same instruction,
  // different shapes (scalars vs. the tuple), and uses of the same value in
  // different expressions.
  auto c1 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f)));
  auto c2 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.1f)));
  auto c3 = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(9.0f)));

  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1, c2));
  auto clamp = builder.AddInstruction(
      HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp, c2, add, add));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add));
  auto mul = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply, exp, c3));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract, mul, clamp));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({sub, sub, mul, c1}));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  auto* fusion = computation->CreateFusionInstruction(
      {tuple, sub, mul, exp, clamp, add}, HloInstruction::FusionKind::kLoop);

  // Operands in the fusion instruction's operands() vector should be in the
  // order in which their users were fused.
  EXPECT_THAT(fusion->operands(), ElementsAre(c1, c3, c2));
  EXPECT_THAT(c1->users(), ElementsAre(fusion));
}
990
991 // Convenience function for comparing two HloInstructions.
Identical(const HloInstruction & instruction1,const HloInstruction & instruction2)992 static bool Identical(const HloInstruction& instruction1,
993 const HloInstruction& instruction2) {
994 // Verify Identical is reflexive for both instructions.
995 EXPECT_TRUE(instruction1.Identical(instruction1));
996 EXPECT_TRUE(instruction2.Identical(instruction2));
997
998 bool is_equal = instruction1.Identical(instruction2);
999 // Verify Identical is symmetric.
1000 EXPECT_EQ(is_equal, instruction2.Identical(instruction1));
1001 return is_equal;
1002 }
1003
1004 // Convenience function for comparing two HloInstructions for structural
1005 // equality.
StructuralEqual(const HloInstruction & instruction1,const HloInstruction & instruction2)1006 static bool StructuralEqual(const HloInstruction& instruction1,
1007 const HloInstruction& instruction2) {
1008 auto eq_operand_shapes = [](const HloInstruction* a,
1009 const HloInstruction* b) {
1010 return ShapeUtil::Equal(a->shape(), b->shape());
1011 };
1012 auto eq_computations = [](const HloComputation* a, const HloComputation* b) {
1013 return *a == *b;
1014 };
1015
1016 // Verify Identical is reflexive for both instructions.
1017 EXPECT_TRUE(
1018 instruction1.Identical(instruction1, eq_operand_shapes, eq_computations));
1019 EXPECT_TRUE(
1020 instruction2.Identical(instruction2, eq_operand_shapes, eq_computations));
1021
1022 bool is_equal =
1023 instruction1.Identical(instruction2, eq_operand_shapes, eq_computations);
1024 // Verify Identical is symmetric.
1025 EXPECT_EQ(is_equal, instruction2.Identical(instruction1, eq_operand_shapes,
1026 eq_computations));
1027 return is_equal;
1028 }
1029
TEST_F(HloInstructionTest, IdenticalInstructions) {
  // Test HloInstruction::Identical with a subset of instruction types.

  // Two fixed 2x2 constant operands used below. They are matrices so that
  // dimension-dependent attributes (e.g. broadcast dimensions) matter.
  auto operand1 = HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
  auto operand2 = HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{10.0, 20.0}, {30.0, 40.0}}));
  const Shape shape = operand1->shape();

  // Convenient short names for the operands.
  HloInstruction* op1 = operand1.get();
  HloInstruction* op2 = operand2.get();

  // Operations which only depend on their operands and opcode.
  EXPECT_TRUE(
      Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
                *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1)));
  EXPECT_FALSE(
      Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
                *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op2)));
  EXPECT_FALSE(
      Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
                *HloInstruction::CreateUnary(shape, HloOpcode::kNegate, op1)));

  // Tuples: operand order matters.
  EXPECT_TRUE(Identical(*HloInstruction::CreateTuple({op1, op2}),
                        *HloInstruction::CreateTuple({op1, op2})));
  EXPECT_FALSE(Identical(*HloInstruction::CreateTuple({op1, op2}),
                         *HloInstruction::CreateTuple({op2, op1})));

  // Broadcasts: broadcast dimensions and output shape both matter.
  EXPECT_TRUE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}),
                        *HloInstruction::CreateBroadcast(shape, op1, {0, 1})));
  EXPECT_FALSE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}),
                         *HloInstruction::CreateBroadcast(shape, op1, {1, 0})));
  Shape bcast_shape1 = ShapeUtil::MakeShape(F32, {2, 2, 42});
  Shape bcast_shape2 = ShapeUtil::MakeShape(F32, {2, 2, 123});
  EXPECT_FALSE(
      Identical(*HloInstruction::CreateBroadcast(bcast_shape1, op1, {0, 1}),
                *HloInstruction::CreateBroadcast(bcast_shape2, op1, {0, 1})));

  // Binary operations.
  EXPECT_TRUE(Identical(
      *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
      *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2)));
  EXPECT_FALSE(Identical(
      *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
      *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op2, op1)));
  EXPECT_FALSE(Identical(
      *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
      *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2)));
}
1086
TEST_F(HloInstructionTest, IdenticalCallInstructions) {
  // t1 and t2 call the same computation (subcomp1); t3 calls subcomp2.
  // Structural equality must hold for the first pair only.
  constexpr char kHloString[] = R"(
HloModule Module

subcomp1 (x: f32[]) -> f32[] {
x = f32[] parameter(0)
ROOT n = f32[] sine(x)
}

subcomp2 (x: f32[]) -> f32[] {
x = f32[] parameter(0)
ROOT n = f32[] cosine(x)
}

ENTRY entry (param: f32[]) -> (f32[], f32[], f32[]) {
p = f32[] parameter(0)
t1 = f32[] call(p), to_apply=subcomp1
t2 = f32[] call(p), to_apply=subcomp1
t3 = f32[] call(p), to_apply=subcomp2
ROOT t = (f32[], f32[], f32[]) tuple(t1, t2, t3)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloString));

  const HloInstruction* tuple = module->entry_computation()->root_instruction();
  const HloInstruction& sine_call_a = *tuple->operand(0);
  const HloInstruction& sine_call_b = *tuple->operand(1);
  const HloInstruction& cosine_call = *tuple->operand(2);

  EXPECT_TRUE(StructuralEqual(sine_call_a, sine_call_b));
  EXPECT_FALSE(StructuralEqual(sine_call_a, cosine_call));
}
1120
TEST_F(HloInstructionTest, FunctionVisitor) {
  // Verify that HloInstruction::Accept with a FunctionVisitor visits every
  // instruction reachable from the root exactly once, operands before users,
  // for this diamond:
  //
  //   param --> negate --+
  //     |                +--> add
  //     +-----> exp -----+
  const Shape f32 = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, f32, "0"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(f32, HloOpcode::kNegate, param));
  HloInstruction* exp = builder.AddInstruction(
      HloInstruction::CreateUnary(f32, HloOpcode::kExp, param));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate, exp));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  // Record the position at which each instruction gets visited.
  int counter = 0;
  absl::flat_hash_map<HloInstruction*, int> position;
  FunctionVisitor visitor([&counter, &position](HloInstruction* inst) {
    EXPECT_FALSE(position.contains(inst));
    position[inst] = counter++;
    return OkStatus();
  });
  EXPECT_IS_OK(add->Accept(&visitor));

  EXPECT_EQ(0, position.at(param));
  // negate and exp are independent, so either may come first.
  const int exp_position = position.at(exp);
  const int negate_position = position.at(negate);
  EXPECT_TRUE(exp_position == 1 || exp_position == 2);
  EXPECT_TRUE(negate_position == 1 || negate_position == 2);
  EXPECT_NE(exp_position, negate_position);
  EXPECT_EQ(3, position.at(add));
}
1160
TEST_F(HloInstructionTest, FullyElementwise) {
  // An add of two f32[5] parameters is elementwise, and elementwise on each
  // of its operands.
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
  HloComputation::Builder builder(TestName());
  auto x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x"));
  auto y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y"));
  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x, y));
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(add->IsElementwise());
  // Use int64_t to match operand_count() and the other operand loops here.
  for (int64_t i = 0; i < add->operand_count(); ++i) {
    EXPECT_TRUE(add->IsElementwiseOnOperand(i));
  }
}
1178
TEST_F(HloInstructionTest, MapIsElementwise) {
  // A map over an f32[10,10] parameter with a scalar parameter-only
  // computation ("id") must report itself as elementwise.
  auto module = CreateNewVerifiedModule();
  const Shape matrix_shape =
      ShapeUtil::MakeShapeWithLayout(F32, {10, 10}, {1, 0});

  HloComputation::Builder builder(TestName());
  HloComputation::Builder identity_builder("id");
  identity_builder.AddInstruction(
      HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0"));
  HloComputation* identity =
      module->AddEmbeddedComputation(identity_builder.Build());

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "x"));
  HloInstruction* map = builder.AddInstruction(
      HloInstruction::CreateMap(matrix_shape, {input}, identity));
  module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(map->IsElementwise());
}
1195
TEST_F(HloInstructionTest, PartiallyElementwise) {
  // Fused expression:
  //
  //   max(div(mul(p0, p1), p2), broadcast(p3))
  //
  // p0, p1 and p2 are f32[3,5]; p3 is f32[5], broadcast into dimension 1 of
  // f32[3,5]. The fusion is not elementwise on p3 because the broadcast is not
  // elementwise, so it is not elementwise overall. It is elementwise on every
  // other operand.
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {5});
  const Shape mat_shape = ShapeUtil::MakeShape(F32, {3, 5});

  HloComputation::Builder builder("PartiallyElementwise");
  auto add_param = [&](int64_t index, const Shape& shape, const char* name) {
    return builder.AddInstruction(
        HloInstruction::CreateParameter(index, shape, name));
  };
  HloInstruction* p0 = add_param(0, mat_shape, "p0");
  HloInstruction* p1 = add_param(1, mat_shape, "p1");
  HloInstruction* p2 = add_param(2, mat_shape, "p2");
  HloInstruction* p3 = add_param(3, vec_shape, "p3");
  HloInstruction* mul = builder.AddInstruction(
      HloInstruction::CreateBinary(mat_shape, HloOpcode::kMultiply, p0, p1));
  HloInstruction* div = builder.AddInstruction(
      HloInstruction::CreateBinary(mat_shape, HloOpcode::kDivide, mul, p2));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(mat_shape, p3, {1}));
  HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
      mat_shape, HloOpcode::kMaximum, div, broadcast));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop);

  EXPECT_FALSE(fusion->IsElementwise());
  for (int64_t i = 0; i < fusion->operand_count(); ++i) {
    const bool is_broadcast_input = fusion->operand(i) == p3;
    EXPECT_EQ(fusion->IsElementwiseOnOperand(i), !is_broadcast_input);
  }
}
1246
TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
  // Fused expression:
  //
  //   sub(min(x, broadcast(y)), broadcast(y))
  //
  // x is f32[5]; y is f32[] broadcast to f32[5], and that single broadcast is
  // used by both min and sub. The fusion must be non-elementwise overall,
  // non-elementwise on y, and elementwise on every other operand.
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {5});

  HloComputation::Builder builder("PartiallyElementwiseWithReuse");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec_shape, "x"));
  HloInstruction* y = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "y"));
  HloInstruction* y_broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(vec_shape, y, {}));
  HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
      vec_shape, HloOpcode::kMinimum, x, y_broadcast));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      vec_shape, HloOpcode::kSubtract, min, y_broadcast));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {sub, y_broadcast, min}, HloInstruction::FusionKind::kLoop);

  EXPECT_FALSE(fusion->IsElementwise());
  for (int64_t i = 0; i < fusion->operand_count(); ++i) {
    const bool is_y = fusion->operand(i) == y;
    EXPECT_EQ(fusion->IsElementwiseOnOperand(i), !is_y);
  }
}
1286
TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
  // Fused expression:
  //
  //   x      y
  //   |      |
  //   |  transpose
  //    \    /
  //     dot
  //
  // Tests that shapes aren't mangled by Clone(). The cloned fusion's root, its
  // operands, and the transpose's operand must keep their shapes, and the
  // clone must be structurally equal to the original.
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  // Named for what it is: a transpose (previously misleadingly "reshape").
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {dot, transpose}, HloInstruction::FusionKind::kLoop);

  auto fusion2 = fusion->Clone();
  const HloInstruction* root = fusion->fused_expression_root();
  const HloInstruction* root2 = fusion2->fused_expression_root();
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), root2->shape()));
  EXPECT_TRUE(
      ShapeUtil::Equal(root->operand(0)->shape(), root2->operand(0)->shape()));
  EXPECT_TRUE(
      ShapeUtil::Equal(root->operand(1)->shape(), root2->operand(1)->shape()));
  EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->operand(0)->shape(),
                               root2->operand(1)->operand(0)->shape()));
  EXPECT_TRUE(StructuralEqual(*fusion, *fusion2));
}
1332
TEST_F(HloInstructionTest, FuseInstructionKeepsInstruction) {
  // `mul` has a single user, the root fusion. Fusing it with FuseInstruction
  // must leave `mul` with no users while keeping it in the same computation
  // as the fusion.
  constexpr char kHloString[] = R"(
HloModule test_module
fused_add {
p0 = f32[32,32]{1,0} parameter(0)
p1 = f32[32,32]{1,0} parameter(1)
ROOT add = f32[32,32]{1,0} add(p0, p1)
}

ENTRY reduce {
p2 = f32[32,32]{1,0} parameter(0)
p3 = f32[32,32]{1,0} parameter(1)
c1 = f32[] constant(1)
broadcast = f32[32,32]{1,0} broadcast(c1), dimensions={}
mul = f32[32,32]{1,0} multiply(p2, p3)
ROOT add = f32[32,32]{1,0} fusion(mul, broadcast), kind=kLoop, calls=fused_add
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloString));
  HloInstruction* fusion = module->entry_computation()->root_instruction();
  HloInstruction* mul = fusion->mutable_operand(0);
  EXPECT_EQ(1, mul->user_count());

  fusion->FuseInstruction(mul);

  EXPECT_EQ(0, mul->user_count());
  EXPECT_EQ(fusion->parent(), mul->parent());
}
1360
TEST_F(HloInstructionTest, FuseInstructionIntoMultiOutputKeepsInstruction) {
  // `mul` has two users: the fusion and the root tuple. Fusing it with
  // FuseInstructionIntoMultiOutput must leave `mul` with no users while
  // keeping it in the same computation as the root.
  constexpr char kHloString[] = R"(
HloModule test_module
fused_add {
p0 = f32[32,32]{1,0} parameter(0)
p1 = f32[32,32]{1,0} parameter(1)
ROOT add = f32[32,32]{1,0} add(p0, p1)
}

ENTRY reduce {
p2 = f32[32,32]{1,0} parameter(0)
p3 = f32[32,32]{1,0} parameter(1)
c1 = f32[] constant(1)
mul = f32[32,32]{1,0} multiply(p2, p3)
broadcast = f32[32,32]{1,0} broadcast(c1), dimensions={}
add = f32[32,32]{1,0} fusion(mul, broadcast), kind=kLoop, calls=fused_add
ROOT root = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(mul, add)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloString));
  HloInstruction* tuple = module->entry_computation()->root_instruction();
  HloInstruction* mul = tuple->mutable_operand(0);
  HloInstruction* fusion = tuple->mutable_operand(1);
  EXPECT_EQ(2, mul->user_count());

  fusion->FuseInstructionIntoMultiOutput(mul);

  EXPECT_EQ(0, mul->user_count());
  EXPECT_EQ(tuple->parent(), mul->parent());
}
1390
TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) {
  // Fused expression:
  //
  //   x      y
  //   |      |
  //   |  transpose
  //    \    /
  //     dot
  //
  // After replacing all uses of x with y, the fusion must have y as its only
  // operand, and its fused computation a single parameter.
  const Shape s = ShapeUtil::MakeShape(F32, {10, 10});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y"));
  // Named for what it is: a transpose (previously misleadingly "reshape").
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      s, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {dot, transpose}, HloInstruction::FusionKind::kLoop);

  // EXPECT_IS_OK (rather than EXPECT_TRUE(...ok())) reports the failing status.
  EXPECT_IS_OK(x->ReplaceAllUsesWith(y));

  EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y));
  EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1);
}
1424
TEST_F(HloInstructionTest, FusionEquality) {
  // Two single-instruction loop fusions over the same parameter (exp vs.
  // negate) must not be structurally equal. A fusion and its clone must be.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
  auto make_unary = [&](HloOpcode opcode) {
    return builder.AddInstruction(
        HloInstruction::CreateUnary(r0f32_, opcode, x));
  };
  HloInstruction* exp = make_unary(HloOpcode::kExp);
  HloInstruction* neg = make_unary(HloOpcode::kNegate);
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloInstruction* exp_fusion = entry->CreateFusionInstruction(
      {exp}, HloInstruction::FusionKind::kLoop);
  HloInstruction* neg_fusion = entry->CreateFusionInstruction(
      {neg}, HloInstruction::FusionKind::kLoop);
  EXPECT_FALSE(StructuralEqual(*exp_fusion, *neg_fusion));

  auto exp_fusion_clone = exp_fusion->Clone();
  EXPECT_TRUE(StructuralEqual(*exp_fusion, *exp_fusion_clone));
}
1446
TEST_F(HloInstructionTest, NestedFusionEquality) {
  // Build multiply(add(dot, ones), subtract(dot, ones)) with
  // dot = dot(a, transpose(b)). Fuse {dot, transpose} into an inner loop
  // fusion, then create one output fusion from {add, inner} and another from
  // {sub, inner}. The first output fusion must equal its clone structurally
  // but differ from the second.
  auto module = CreateNewVerifiedModule();
  HloComputation::Builder builder(TestName());

  const Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* a = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 0.0}, {0.0, 1.0}})));
  HloInstruction* b = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
  HloInstruction* b_transposed = builder.AddInstruction(
      HloInstruction::CreateTranspose(data_shape, b, {1, 0}));

  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(1);
  contraction.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      data_shape, a, b_transposed, contraction, DefaultPrecisionConfig(2)));

  HloInstruction* one = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* ones = builder.AddInstruction(
      HloInstruction::CreateBroadcast(data_shape, one, {}));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, dot, ones));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      data_shape, HloOpcode::kSubtract, dot, ones));
  builder.AddInstruction(
      HloInstruction::CreateBinary(data_shape, HloOpcode::kMultiply, add, sub));
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloInstruction* inner = entry->CreateFusionInstruction(
      {dot, b_transposed}, HloInstruction::FusionKind::kLoop);
  HloInstruction* add_fusion = entry->CreateFusionInstruction(
      {add, inner}, HloInstruction::FusionKind::kOutput);
  HloInstruction* sub_fusion = entry->CreateFusionInstruction(
      {sub, inner}, HloInstruction::FusionKind::kOutput);

  auto add_fusion_clone = add_fusion->Clone();
  EXPECT_TRUE(StructuralEqual(*add_fusion, *add_fusion_clone));
  EXPECT_FALSE(StructuralEqual(*add_fusion, *sub_fusion));
}
1487
TEST_F(HloInstructionTest, CloneSuffixNames) {
  // Cloning must not stack suffixes: a clone of a clone is named "foo.clone2"
  // rather than "foo.clone.clone"; a trailing numeric counter after the
  // suffix is incremented instead.
  const auto make_scalar_param = [](const std::string& name) {
    return HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}),
                                           name);
  };

  // Repeated cloning with the default suffix.
  const auto foo = make_scalar_param("foo");
  EXPECT_EQ(foo->Clone()->name(), "foo.clone");
  EXPECT_EQ(foo->Clone()->Clone()->name(), "foo.clone2");
  EXPECT_EQ(foo->Clone()->Clone()->Clone()->name(), "foo.clone3");

  // Repeated cloning with a caller-supplied suffix; switching back to the
  // default suffix appends a fresh ".clone".
  EXPECT_EQ(foo->Clone("bar")->name(), "foo.bar");
  EXPECT_EQ(foo->Clone("bar")->Clone("bar")->name(), "foo.bar2");
  EXPECT_EQ(foo->Clone("bar")->Clone("bar")->Clone()->name(),
            "foo.bar2.clone");

  // A dot inside the original name is kept as-is.
  EXPECT_EQ(make_scalar_param("foo.baz")->Clone()->name(), "foo.baz.clone");

  // An existing multi-digit counter after the suffix is incremented.
  EXPECT_EQ(make_scalar_param("foo.clone234")->Clone()->name(),
            "foo.clone235");

  // Non-numeric text after the suffix is not treated as a counter.
  EXPECT_EQ(make_scalar_param("foo.clonexyz")->Clone()->name(),
            "foo.clonexyz.clone");

  // With the suffix appearing twice, only the trailing counter is bumped.
  EXPECT_EQ(make_scalar_param("foo.clone.clone3")->Clone()->name(),
            "foo.clone.clone4");
}
1525
TEST_F(HloInstructionTest, Stringification) {
  // Tests stringification of a simple op (a dot), a while, and a conditional.
  // The dot is printed twice: once with default options minus metadata, and
  // once with operand shapes, '%' prefixes and layouts suppressed.
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto options = HloPrintOptions().set_print_metadata(false);

  EXPECT_EQ(dot->ToString(options),
            "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} "
            "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}");

  auto options2 = HloPrintOptions()
                      .set_print_metadata(false)
                      .set_print_operand_shape(false)
                      .set_print_percent(false)
                      .set_include_layout_in_shapes(false);

  EXPECT_EQ(dot->ToString(options2),
            "dot = f32[5,20] dot(x, transpose), "
            "lhs_contracting_dims={1}, rhs_contracting_dims={0}");

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());

  // NOTE(review): `builder` was already consumed by Build() above. The while
  // and conditional below are still added through it, so the builder owns
  // them. They are presumably not part of `computation`; they are only
  // printed here.
  HloInstruction* loop = builder.AddInstruction(
      HloInstruction::CreateWhile(sout, computation, computation, x));
  EXPECT_EQ(loop->ToString(options),
            "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), "
            "condition=%TransposeDot, body=%TransposeDot");

  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          sout, pred, x, computation, x, computation));
  EXPECT_EQ(conditional->ToString(options),
            "%conditional = f32[5,20]{1,0} conditional(pred[] %constant, "
            "f32[5,10]{1,0} %x, f32[5,10]{1,0} %x), "
            "true_computation=%TransposeDot, false_computation=%TransposeDot");
}
1581
TEST_F(HloInstructionTest, StringifyGather_0) {
  // Prints a gather whose index vector is the trailing dimension
  // (index_vector_dim=4) of the rank-5 start indices.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  const Shape indices_shape = ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
  const Shape result_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Gather");
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "input_tensor"));
  HloInstruction* indices = builder.AddInstruction(
      HloInstruction::CreateParameter(1, indices_shape, "start_indices"));

  const GatherDimensionNumbers dim_numbers =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4);
  HloInstruction* gather = builder.AddInstruction(HloInstruction::CreateGather(
      result_shape, operand, indices, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26},
      /*indices_are_sorted=*/false));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::string expected =
      "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
      "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
      "s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), "
      "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
      "start_index_map={0,1,2,3,4}, "
      "index_vector_dim=4, slice_sizes={30,29,28,27,26}";
  EXPECT_EQ(gather->ToString(), expected);
}
1617
TEST_F(HloInstructionTest, StringifyGather_1) {
  // Prints a gather whose index vector sits in the middle of the start
  // indices (index_vector_dim=2, the size-5 dimension).
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  const Shape indices_shape = ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  const Shape result_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Gather");
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "input_tensor"));
  HloInstruction* indices = builder.AddInstruction(
      HloInstruction::CreateParameter(1, indices_shape, "start_indices"));

  const GatherDimensionNumbers dim_numbers =
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/2);
  HloInstruction* gather = builder.AddInstruction(HloInstruction::CreateGather(
      result_shape, operand, indices, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26},
      /*indices_are_sorted=*/false));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

  const std::string expected =
      "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
      "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
      "s64[10,9,5,7,6]{4,3,2,1,0} %start_indices), "
      "offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, "
      "start_index_map={0,1,2,3,4}, "
      "index_vector_dim=2, slice_sizes={30,29,28,27,26}";
  EXPECT_EQ(gather->ToString(), expected);
}
1653
TEST_F(HloInstructionTest, StringifyScatter) {
  // Prints a scatter (index_vector_dim=2 in the middle of the indices)
  // together with the name of its to_apply computation.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  const Shape indices_shape = ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});

  HloComputation::Builder builder("Scatter");
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "input_tensor"));
  HloInstruction* indices = builder.AddInstruction(
      HloInstruction::CreateParameter(1, indices_shape, "scatter_indices"));
  HloInstruction* updates = builder.AddInstruction(
      HloInstruction::CreateParameter(2, updates_shape, "scatter_updates"));

  // Update computation: just two scalar f32 parameters.
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder combiner_builder("Scatter.update");
  combiner_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_shape, "p1"));
  combiner_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "p2"));

  auto module = CreateNewVerifiedModule();
  HloComputation* combiner =
      module->AddEmbeddedComputation(combiner_builder.Build());

  const ScatterDimensionNumbers dim_numbers =
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 7, 8},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/2);
  HloInstruction* scatter =
      builder.AddInstruction(HloInstruction::CreateScatter(
          operand_shape, operand, indices, updates, combiner, dim_numbers,
          /*indices_are_sorted=*/false,
          /*unique_indices=*/false));
  module->AddEntryComputation(builder.Build());

  const std::string expected =
      "%scatter = f32[50,49,48,47,46]{4,3,2,1,0} "
      "scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
      "s64[10,9,5,7,6]{4,3,2,1,0} %scatter_indices, "
      "f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %scatter_updates), "
      "update_window_dims={4,5,6,7,8}, inserted_window_dims={}, "
      "scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, "
      "to_apply=%Scatter.update";
  EXPECT_EQ(scatter->ToString(), expected);
}
1704
TEST_F(HloInstructionTest, StringifyAsyncOps) {
  // Prints an async-start/update/done chain that wraps a custom call, once
  // with async syntax sugar (custom-call-start/-update/-done) and once without
  // it (generic async-* ops plus the wrapped computation printed separately).
  const Shape input_shape = ShapeUtil::MakeShape(F32, {10});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {20});
  const Shape context_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeTupleShape({input_shape}), output_shape,
       ShapeUtil::MakeShape(S32, {})});

  // The computation run asynchronously: one custom call on its parameter.
  HloComputation::Builder wrapped_builder("AsyncOp");
  HloInstruction* wrapped_param = wrapped_builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "p0"));
  wrapped_builder.AddInstruction(
      HloInstruction::CreateCustomCall(output_shape, {wrapped_param},
                                       /*custom_call_target=*/"foo"));
  std::unique_ptr<HloComputation> wrapped = wrapped_builder.Build();
  HloComputation* const wrapped_ptr = wrapped.get();

  HloComputation::Builder entry_builder("Entry");
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "p0"));
  HloInstruction* start =
      entry_builder.AddInstruction(HloInstruction::CreateAsyncStart(
          context_shape, {entry_param}, wrapped_ptr,
          /*async_group_id=*/std::nullopt,
          /*async_execution_thread=*/"parallel_thread"));
  HloInstruction* update =
      entry_builder.AddInstruction(HloInstruction::CreateAsyncUpdate(
          context_shape, start, wrapped_ptr,
          /*async_group_id=*/std::nullopt,
          /*async_execution_thread=*/"parallel_thread"));
  entry_builder.AddInstruction(HloInstruction::CreateAsyncDone(
      output_shape, update, wrapped_ptr,
      /*async_group_id=*/std::nullopt,
      /*async_execution_thread=*/"parallel_thread"));

  // Keep this registration order (entry first, then the wrapped computation).
  // The expected text shows the wrapped parameter as "p0.1", presumably
  // uniquified against the entry's "p0".
  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(entry_builder.Build());
  module->AddEmbeddedComputation(std::move(wrapped));

  const std::string expected_sugared =
      R"(HloModule StringifyAsyncOps, entry_computation_layout={(f32[10]{0})->f32[20]{0}}

ENTRY %Entry (p0: f32[10]) -> f32[20] {
  %p0 = f32[10]{0} parameter(0)
  %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-start(f32[10]{0} %p0), async_execution_thread="parallel_thread", custom_call_target="foo"
  %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) custom-call-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="foo"
  ROOT %async-done = f32[20]{0} custom-call-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_execution_thread="parallel_thread", custom_call_target="foo"
}

)";
  EXPECT_EQ(module->ToString(), expected_sugared);

  const std::string expected_desugared =
      R"(HloModule StringifyAsyncOps, entry_computation_layout={(f32[10]{0})->f32[20]{0}}

%AsyncOp (p0.1: f32[10]) -> f32[20] {
  %p0.1 = f32[10]{0} parameter(0)
  ROOT %custom-call = f32[20]{0} custom-call(f32[10]{0} %p0.1), custom_call_target="foo"
}, execution_thread="parallel_thread"

ENTRY %Entry (p0: f32[10]) -> f32[20] {
  %p0 = f32[10]{0} parameter(0)
  %async-start = ((f32[10]{0}), f32[20]{0}, s32[]) async-start(f32[10]{0} %p0), async_execution_thread="parallel_thread", calls=%AsyncOp
  %async-update = ((f32[10]{0}), f32[20]{0}, s32[]) async-update(((f32[10]{0}), f32[20]{0}, s32[]) %async-start), async_execution_thread="parallel_thread", calls=%AsyncOp
  ROOT %async-done = f32[20]{0} async-done(((f32[10]{0}), f32[20]{0}, s32[]) %async-update), async_execution_thread="parallel_thread", calls=%AsyncOp
}

)";
  const HloPrintOptions no_sugar =
      HloPrintOptions().set_syntax_sugar_async_ops(false);
  EXPECT_EQ(module->ToString(no_sugar), expected_desugared);
}
1772
TEST_F(HloInstructionTest, CanonicalStringificationFusion) {
  // Tests canonical stringification of a dot and of a loop fusion built from
  // it. The fusion's called computation carries an execution thread and is
  // printed inline with canonical tmp_N names.
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto options = HloPrintOptions().Canonical();

  EXPECT_EQ(dot->ToString(options),
            "f32[5,20]{1,0} dot(f32[5,10]{1,0}, f32[10,20]{1,0}), "
            "lhs_contracting_dims={1}, rhs_contracting_dims={0}");

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  constexpr char kParallelThreadName[] = "parallel_thread";
  computation->SetExecutionThread(kParallelThreadName);
  HloInstruction* fusion = computation->CreateFusionInstruction(
      {dot, transpose}, HloInstruction::FusionKind::kLoop);
  fusion->set_called_computations_execution_thread(
      kParallelThreadName,
      /*skip_async_execution_thread_overwrite=*/false);

  const std::string expected_fusion =
      R"(f32[5,20]{1,0} fusion(f32[5,10]{1,0}, f32[20,10]{1,0}), kind=kLoop, calls=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
  ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}, execution_thread="parallel_thread")";
  EXPECT_EQ(fusion->ToString(options), expected_fusion);
}
1819
TEST_F(HloInstructionTest, CanonicalStringificationWhile) {
  // Tests canonical stringification of a while whose condition and body are
  // the same computation. That computation contains a loop fusion, so both
  // called computations and the nested fused computation print inline.
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  computation->CreateFusionInstruction({dot, transpose},
                                       HloInstruction::FusionKind::kLoop);

  // NOTE(review): the while is added through the already-built `builder`,
  // which owns it. Only its printed form is checked.
  HloInstruction* loop = builder.AddInstruction(
      HloInstruction::CreateWhile(sout, computation, computation, x));

  auto options = HloPrintOptions().Canonical();
  const std::string expected_loop =
      R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
  {
    tmp_0 = f32[5,10]{1,0} parameter(0)
    tmp_1 = f32[20,10]{1,0} parameter(1)
    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  }
}, body=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
  {
    tmp_0 = f32[5,10]{1,0} parameter(0)
    tmp_1 = f32[20,10]{1,0} parameter(1)
    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  }
})";
  EXPECT_EQ(loop->ToString(options), expected_loop);
}
1875
TEST_F(HloInstructionTest, CanonicalStringificationConditional) {
  // Tests canonical stringification of a conditional whose true and false
  // branches are the same computation. That computation contains a loop
  // fusion, so both branches and the nested fused computation print inline.
  const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
  const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
  const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
  const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});

  HloComputation::Builder builder("TransposeDot");
  HloInstruction* x =
      builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
  HloInstruction* y =
      builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
  HloInstruction* transpose =
      builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      sout, x, transpose, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());
  computation->CreateFusionInstruction({dot, transpose},
                                       HloInstruction::FusionKind::kLoop);

  // NOTE(review): the predicate and conditional are added through the
  // already-built `builder`, which owns them. Only the conditional's printed
  // form is checked.
  auto pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          sout, pred, x, computation, x, computation));
  auto options = HloPrintOptions().Canonical();
  const std::string expected_conditional =
      R"(f32[5,20]{1,0} conditional(pred[], f32[5,10]{1,0}, f32[5,10]{1,0}), true_computation=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
  {
    tmp_0 = f32[5,10]{1,0} parameter(0)
    tmp_1 = f32[20,10]{1,0} parameter(1)
    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  }
}, false_computation=
{
  tmp_0 = f32[5,10]{1,0} parameter(0)
  tmp_1 = f32[20,10]{1,0} parameter(1)
  ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
  {
    tmp_0 = f32[5,10]{1,0} parameter(0)
    tmp_1 = f32[20,10]{1,0} parameter(1)
    tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
    ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
  }
})";
  EXPECT_EQ(conditional->ToString(options), expected_conditional);
}
1936
TEST_F(HloInstructionTest, CheckDeepClone) {
  // Module with several kinds of called computations (while condition/body,
  // call, reduce-window to_apply), so a deep clone has to recreate nested
  // computations as well as instructions.
  constexpr char kModuleText[] = R"(
HloModule Module

addy (lhs: s32[], rhs: s32[]) -> s32[] {
  lhs = s32[] parameter(0)
  rhs = s32[] parameter(1)
  ROOT zadd = s32[] add(lhs, rhs)
}

calla (x: s32[]) -> s32[] {
  x = s32[] parameter(0)
  reduce = s32[] reduce-window(x, x), to_apply=addy
  ROOT xadd = s32[] add(x, reduce)
}

body (bparam: s32[]) -> s32[] {
  constant = s32[] constant(1)
  bparam = s32[] parameter(0)
  v = s32[] call(bparam), to_apply=calla
  ROOT add = s32[] add(constant, bparam)
}

condition (cparam: s32[]) -> pred[] {
  xconstant = s32[] constant(5)
  cparam = s32[] parameter(0)
  ROOT greater-than = pred[] compare(xconstant, cparam), direction=GT
}

ENTRY entry (param: s32[]) -> s32[] {
  eparam = s32[] parameter(0)
  ROOT while = s32[] while(eparam), condition=condition, body=body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));
  std::unique_ptr<HloModule> clone = module->Clone();
  HloModule* const cloned_module = clone.get();

  // Every computation in the clone, and every instruction inside them, must
  // point back at the cloned module rather than at `module`.
  for (auto* cloned_computation : cloned_module->computations()) {
    EXPECT_EQ(cloned_computation->parent(), cloned_module);
    for (auto* cloned_instruction : cloned_computation->instructions()) {
      EXPECT_EQ(cloned_instruction->parent()->parent(), cloned_module);
    }
  }
}
1983
TEST_F(HloInstructionTest, IdenticalAccountsForBackendConfig) {
  // Two otherwise-identical adds stop comparing Identical() once one of them
  // carries a raw backend config string.
  const Shape vec_shape = ShapeUtil::MakeShape(F32, {42});
  HloComputation::Builder builder("test");
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, vec_shape, "p"));

  const auto add_param_to_itself = [&]() {
    return builder.AddInstruction(
        HloInstruction::CreateBinary(vec_shape, HloOpcode::kAdd, param, param));
  };
  HloInstruction* first_add = add_param_to_itself();
  HloInstruction* second_add = add_param_to_itself();

  EXPECT_TRUE(first_add->Identical(*second_add));
  first_add->set_raw_backend_config_string("abc");
  EXPECT_FALSE(first_add->Identical(*second_add));
}
1999
TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallWindow) {
  // A custom call and its clone compare Identical() until a window is set on
  // only one of them.
  std::unique_ptr<HloInstruction> original = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}),
      /*operands=*/{},
      /*custom_call_target=*/"foo");
  const std::unique_ptr<HloInstruction> copy = original->Clone();
  EXPECT_TRUE(original->Identical(*copy));

  original->set_window(window_util::MakeWindow({1, 2, 3}));
  EXPECT_FALSE(original->Identical(*copy));
}
2011
TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallDnums) {
  // A custom call and its clone compare Identical() until convolution
  // dimension numbers are set on only one of them.
  std::unique_ptr<HloInstruction> original = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}),
      /*operands=*/{},
      /*custom_call_target=*/"foo");
  const std::unique_ptr<HloInstruction> copy = original->Clone();
  EXPECT_TRUE(original->Identical(*copy));

  ConvolutionDimensionNumbers conv_dnums;
  conv_dnums.set_output_batch_dimension(42);
  original->set_convolution_dimension_numbers(conv_dnums);
  EXPECT_FALSE(original->Identical(*copy));
}
2024
TEST_F(HloInstructionTest, IdenticalAccountsForCustomCallHasSideEffect) {
  // A custom call and its clone compare Identical() until only one of them is
  // marked as having a side effect.
  std::unique_ptr<HloInstruction> original = HloInstruction::CreateCustomCall(
      ShapeUtil::MakeShape(F32, {}),
      /*operands=*/{},
      /*custom_call_target=*/"foo");
  const std::unique_ptr<HloInstruction> copy = original->Clone();
  EXPECT_TRUE(original->Identical(*copy));

  Cast<HloCustomCallInstruction>(original.get())
      ->set_custom_call_has_side_effect(true);
  EXPECT_FALSE(original->Identical(*copy));
}
2036
TEST_F(HloInstructionTest, CloneWindowOnCustomCall) {
  // Clone() must carry a custom call's window over to the copy.
  std::unique_ptr<HloInstruction> custom_call =
      HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
                                       /*operands=*/{},
                                       /*custom_call_target=*/"foo");
  const Window window = window_util::MakeWindow({1, 2, 3});
  custom_call->set_window(window);

  const std::unique_ptr<HloInstruction> copy = custom_call->Clone();
  const Window& copied_window = copy->window();
  EXPECT_TRUE(protobuf_util::ProtobufEquals(copied_window, window))
      << copied_window.DebugString();
}
2047
TEST_F(HloInstructionTest, CloneDnumsOnCustomCall) {
  // Clone() must carry a custom call's convolution dimension numbers over to
  // the copy.
  std::unique_ptr<HloInstruction> custom_call =
      HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
                                       /*operands=*/{},
                                       /*custom_call_target=*/"foo");
  ConvolutionDimensionNumbers conv_dnums;
  conv_dnums.set_output_batch_dimension(42);
  custom_call->set_convolution_dimension_numbers(conv_dnums);

  const std::unique_ptr<HloInstruction> copy = custom_call->Clone();
  const ConvolutionDimensionNumbers& copied_dnums =
      copy->convolution_dimension_numbers();
  EXPECT_TRUE(protobuf_util::ProtobufEquals(copied_dnums, conv_dnums))
      << copied_dnums.DebugString();
}
2060
TEST_F(HloInstructionTest, CloneHasSideEffectOnCustomCall) {
  // The has-side-effect flag defaults to false, can be set, and survives
  // Clone().
  std::unique_ptr<HloInstruction> custom_call =
      HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
                                       /*operands=*/{},
                                       /*custom_call_target=*/"foo");
  auto* as_custom_call = Cast<HloCustomCallInstruction>(custom_call.get());
  EXPECT_FALSE(as_custom_call->custom_call_has_side_effect());
  as_custom_call->set_custom_call_has_side_effect(true);
  EXPECT_TRUE(as_custom_call->custom_call_has_side_effect());

  const std::unique_ptr<HloInstruction> copy = custom_call->Clone();
  EXPECT_TRUE(Cast<HloCustomCallInstruction>(copy.get())
                  ->custom_call_has_side_effect());
}
2073
TEST_F(HloInstructionTest, CustomCallHasSideEffect) {
  // HloInstruction::HasSideEffect() follows the custom call's side-effect flag.
  std::unique_ptr<HloInstruction> custom_call =
      HloInstruction::CreateCustomCall(ShapeUtil::MakeShape(F32, {}),
                                       /*operands=*/{},
                                       /*custom_call_target=*/"foo");
  EXPECT_FALSE(custom_call->HasSideEffect());
  Cast<HloCustomCallInstruction>(custom_call.get())
      ->set_custom_call_has_side_effect(true);
  EXPECT_TRUE(custom_call->HasSideEffect());
}
2083
TEST_F(HloInstructionTest, PreserveOperandPrecisionOnCloneConv) {
  // Per-operand precision ({high,default}) on a convolution survives Clone().
  constexpr char kHloString[] = R"(
  HloModule test_module
  ENTRY test {
    arg0 = f32[1,2,1] parameter(0)
    arg1 = f32[1,1,1] parameter(1)
    ROOT conv = f32[1,2,1] convolution(arg0, arg1), window={size=1},
      dim_labels=b0f_0io->b0f, operand_precision={high,default}
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloString));
  const HloInstruction* conv = module->entry_computation()->root_instruction();

  const std::unique_ptr<HloInstruction> copy = conv->Clone();
  EXPECT_THAT(copy->precision_config().operand_precision(),
              ElementsAre(PrecisionConfig::HIGH, PrecisionConfig::DEFAULT));
}
2102
TEST_F(HloInstructionTest, PreserveOuterDimensionPartitionsOnClone) {
  // outer_dimension_partitions parsed from HLO text survive Clone().
  constexpr char kHloString[] = R"(
  HloModule test_module
  ENTRY test {
    ROOT iota = f32[100] iota(), iota_dimension=0, outer_dimension_partitions={0, 50}
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloString));
  const HloInstruction* iota = module->entry_computation()->root_instruction();

  const std::unique_ptr<HloInstruction> copy = iota->Clone();
  EXPECT_THAT(copy->outer_dimension_partitions(), ElementsAre(0, 50));
}
2117
TEST_F(HloInstructionTest, ReuseReshapeOfFusionParameter) {
  // The fused computation reads parameter 0 only through a single reshape `r`,
  // which is then consumed several times. Because every read goes through that
  // same reshape (UseKind::kUsePermutingElements), the fusion is expected NOT
  // to be reported as reusing operand 0's elements.
  constexpr char kHloString[] = R"(
  HloModule test_module
  f {
    p = f32[3,2] parameter(0)
    r = f32[2,3] reshape(p)
    x = f32[2,3] multiply(r, r)
    y = f32[2,3] add(r, r)
    ROOT sum = f32[2,3] add(x, y)
  }
  ENTRY test {
    p = f32[3,2] parameter(0)
    ROOT fusion = f32[2,3] fusion(p), calls=f, kind=kLoop
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloString));
  const HloInstruction* fusion =
      module->entry_computation()->root_instruction();
  EXPECT_FALSE(fusion->ReusesOperandElements(0));
}
2140
TEST_F(HloInstructionTest, ReuseMultipleReshapesOfFusionParameter) {
  // Create a fusion node that reads its parameter through two *different*
  // reshapes (r1 and r2), each used once. ReuseReshapeOfFusionParameter
  // covers a single shared reshape, which only permutes the elements and is
  // reported as "does not reuse this operand". Here the two distinct reshapes
  // read the parameter's elements through two different index mappings, so
  // ReusesOperandElements(0) is expected to report reuse.
  // NOTE(review): presumably this corresponds to UseKind::kReuse internally,
  // not kUsePermutingElements — confirm against the use-kind computation.
  constexpr char kHloString[] = R"(
  HloModule test_module
  f {
    p = f32[3,2] parameter(0)
    r1 = f32[2,3] reshape(p)
    r2 = f32[6,1] reshape(p)
    ROOT result = (f32[2,3], f32[6,1]) tuple(r1, r2)
  }
  ENTRY test {
    p = f32[3,2] parameter(0)
    ROOT fusion = (f32[2,3], f32[6,1]) fusion(p), calls=f, kind=kLoop
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloString));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_TRUE(root->ReusesOperandElements(0));
}
2163
TEST_F(HloInstructionTest, BitcastDoesNotReuseElements) {
  // A bitcast of an f32[3,2] parameter (with a {0,1} layout) to f32[6] must not
  // be reported as reusing its operand's elements.
  constexpr char kHloString[] = R"(
  HloModule test_module
  ENTRY test {
    p = f32[3,2]{0,1} parameter(0)
    ROOT bitcast = f32[6] bitcast(p)
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHloString));
  const HloInstruction* bitcast =
      module->entry_computation()->root_instruction();
  EXPECT_FALSE(bitcast->ReusesOperandElements(0));
}
2176
TEST_F(HloInstructionTest, GatherDoesNotReuseElements) {
  // The gather reads from `input` (operand 0) using start indices from `idx`
  // (operand 1); neither operand is expected to be reported as reused.
  constexpr char kModuleText[] = R"(
HloModule test_module

ENTRY test {
input = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
idx = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0}
gather(input, idx), offset_dims={4,5,6,7,8}, collapsed_slice_dims={},
start_index_map={0,1,2,3,4}, index_vector_dim=4,
slice_sizes={30,29,28,27,26}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(kModuleText));
  const HloInstruction* gather =
      parsed_module->entry_computation()->root_instruction();
  // Check operand 0 (data) first, then operand 1 (indices).
  for (int operand_index = 0; operand_index < 2; ++operand_index) {
    EXPECT_FALSE(gather->ReusesOperandElements(operand_index));
  }
}
2195
TEST_F(HloInstructionTest, BackendConfigCanContainNonFiniteFloats) {
  // A GEMM backend config carrying +inf and NaN must survive a round trip
  // through set_backend_config() and backend_config<>().
  HloComputation::Builder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {2, 2});

  // Contract lhs dimension 1 against rhs dimension 0 (an ordinary matmul).
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);

  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, matrix_shape, "p0"));
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      matrix_shape, param, param, dnums, DefaultPrecisionConfig(2)));

  const double kInfinity = std::numeric_limits<double>::infinity();
  const double kNaN = std::numeric_limits<double>::quiet_NaN();
  gpu::GemmBackendConfig config_in;
  config_in.set_alpha_real(kInfinity);
  config_in.set_alpha_imag(kNaN);
  TF_ASSERT_OK(dot->set_backend_config(config_in));

  TF_ASSERT_OK_AND_ASSIGN(gpu::GemmBackendConfig config_out,
                          dot->backend_config<gpu::GemmBackendConfig>());
  // Only +inf compares strictly greater than the largest finite double.
  EXPECT_GT(config_out.alpha_real(), std::numeric_limits<double>::max());
  // NaN is the only value that compares unequal to itself.
  EXPECT_NE(config_out.alpha_imag(), config_out.alpha_imag());
}
2216
2217 } // namespace
2218 } // namespace xla
2219