1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/xla_builder.h"
17
18 #include <string>
19
20 #include "tensorflow/compiler/xla/client/value_inference.h"
21 #include "tensorflow/compiler/xla/client/xla_computation.h"
22 #include "tensorflow/compiler/xla/debug_options_flags.h"
23 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
24 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
25 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
26 #include "tensorflow/compiler/xla/service/hlo_module.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33
34 namespace xla {
35
36 namespace {
37
38 namespace op = xla::testing::opcode_matchers;
39
40 using ::testing::HasSubstr;
41
42 // TODO(b/74197823): Move the tests to service/.
43 class XlaBuilderTest : public ::testing::Test {
44 protected:
BuildHloModule(XlaBuilder * b)45 StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b) {
46 TF_ASSIGN_OR_RETURN(XlaComputation computation,
47 b->Build(/*remove_dynamic_dimensions=*/false));
48 const HloModuleProto& proto = computation.proto();
49 TF_ASSIGN_OR_RETURN(const auto& config,
50 HloModule::CreateModuleConfigFromProto(
51 proto, GetDebugOptionsFromFlags()));
52 return HloModule::CreateFromProto(proto, config);
53 }
54
55 // Overload which explicitly specifies the root instruction.
BuildHloModule(XlaBuilder * b,XlaOp root)56 StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b,
57 XlaOp root) {
58 TF_ASSIGN_OR_RETURN(XlaComputation computation,
59 b->Build(root, /*remove_dynamic_dimensions=*/false));
60 const HloModuleProto& proto = computation.proto();
61 TF_ASSIGN_OR_RETURN(const auto& config,
62 HloModule::CreateModuleConfigFromProto(
63 proto, GetDebugOptionsFromFlags()));
64 return HloModule::CreateFromProto(proto, config);
65 }
66
67 // Returns the name of the test currently being run.
TestName() const68 std::string TestName() const {
69 return ::testing::UnitTest::GetInstance()->current_test_info()->name();
70 }
71 };
72
// Adding two scalar constants should produce an add whose operands are both
// constants.
TEST_F(XlaBuilderTest, OnePlusTwo) {
  XlaBuilder builder(TestName());
  Add(ConstantR0<float>(&builder, 1.0), ConstantR0<float>(&builder, 2.0));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const HloInstruction* result = hlo_module->entry_computation()
                                     ->root_instruction();
  EXPECT_THAT(result, op::Add(op::Constant(), op::Constant()));
}
80
// The overloaded unary operators on XlaOp (- and ~) should emit Negate and Not
// respectively.
TEST_F(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) {
  using RootMatcher = ::testing::Matcher<const ::xla::HloInstruction*>;

  // Applies `unary_fn` to an s32 constant in a fresh builder and checks the
  // resulting root instruction against `expected_root`.
  auto check_unary = [&](std::function<XlaOp(XlaOp)> unary_fn,
                         RootMatcher expected_root) {
    XlaBuilder builder(TestName());
    unary_fn(ConstantR0<int32_t>(&builder, 1));
    TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
    EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
                expected_root);
  };

  check_unary([](XlaOp operand) { return -operand; },
              op::Negate(op::Constant()));
  check_unary([](XlaOp operand) { return ~operand; }, op::Not(op::Constant()));
}
94
// Each overloaded binary operator on XlaOp should emit the matching HLO
// opcode. `>>` maps to an arithmetic shift for signed operands and a logical
// shift for unsigned ones.
TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) {
  using RootMatcher = ::testing::Matcher<const ::xla::HloInstruction*>;
  using BinaryFn = std::function<XlaOp(XlaOp, XlaOp)>;

  // Applies `binary_fn` to two s32 constants and checks the resulting root.
  auto check_signed = [&](BinaryFn binary_fn, RootMatcher expected_root) {
    XlaBuilder builder(TestName());
    binary_fn(ConstantR0<int32_t>(&builder, 1),
              ConstantR0<int32_t>(&builder, 2));
    TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
    EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
                expected_root);
  };

  // Arithmetic operators.
  check_signed([](XlaOp lhs, XlaOp rhs) { return lhs + rhs; },
               op::Add(op::Constant(), op::Constant()));
  check_signed([](XlaOp lhs, XlaOp rhs) { return lhs - rhs; },
               op::Subtract(op::Constant(), op::Constant()));
  check_signed([](XlaOp lhs, XlaOp rhs) { return lhs * rhs; },
               op::Multiply(op::Constant(), op::Constant()));
  check_signed([](XlaOp lhs, XlaOp rhs) { return lhs / rhs; },
               op::Divide(op::Constant(), op::Constant()));

  // Bitwise and shift operators.
  check_signed([](XlaOp lhs, XlaOp rhs) { return lhs & rhs; },
               op::And(op::Constant(), op::Constant()));
  check_signed([](XlaOp lhs, XlaOp rhs) { return lhs | rhs; },
               op::Or(op::Constant(), op::Constant()));
  check_signed([](XlaOp lhs, XlaOp rhs) { return lhs ^ rhs; },
               op::Xor(op::Constant(), op::Constant()));
  check_signed([](XlaOp lhs, XlaOp rhs) { return lhs << rhs; },
               op::ShiftLeft(op::Constant(), op::Constant()));
  check_signed([](XlaOp lhs, XlaOp rhs) { return lhs >> rhs; },
               op::ShiftRightArithmetic(op::Constant(), op::Constant()));

  // Same as check_signed but with u32 constants.
  auto check_unsigned = [&](BinaryFn binary_fn, RootMatcher expected_root) {
    XlaBuilder builder(TestName());
    binary_fn(ConstantR0<uint32_t>(&builder, 1),
              ConstantR0<uint32_t>(&builder, 2));
    TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
    EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
                expected_root);
  };
  check_unsigned([](XlaOp lhs, XlaOp rhs) { return lhs >> rhs; },
                 op::ShiftRightLogical(op::Constant(), op::Constant()));
}
140
// A three-operand And should lower to two nested binary ands over the three
// PRED parameters.
TEST_F(XlaBuilderTest, VariadicAnd) {
  XlaBuilder builder(TestName());
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  And(Parameter(&builder, 0, pred_scalar, "p0"),
      Parameter(&builder, 1, pred_scalar, "p1"),
      Parameter(&builder, 2, pred_scalar, "p2"));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  // The test deliberately accepts either associativity.
  auto right_nested =
      op::And(op::Parameter(0), op::And(op::Parameter(1), op::Parameter(2)));
  auto left_nested =
      op::And(op::And(op::Parameter(0), op::Parameter(1)), op::Parameter(2));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              ::testing::AnyOf(right_nested, left_nested));
}
156
// A three-operand Or should lower to two nested binary ors over the three
// PRED parameters.
TEST_F(XlaBuilderTest, VariadicOr) {
  XlaBuilder builder(TestName());
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  Or(Parameter(&builder, 0, pred_scalar, "p0"),
     Parameter(&builder, 1, pred_scalar, "p1"),
     Parameter(&builder, 2, pred_scalar, "p2"));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  // The test deliberately accepts either associativity.
  auto right_nested =
      op::Or(op::Parameter(0), op::Or(op::Parameter(1), op::Parameter(2)));
  auto left_nested =
      op::Or(op::Or(op::Parameter(0), op::Parameter(1)), op::Parameter(2));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              ::testing::AnyOf(right_nested, left_nested));
}
172
// Applying `>>` to floating-point operands must make Build() fail with a
// descriptive error.
TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) {
  XlaBuilder builder(TestName());
  ConstantR0<float>(&builder, 1) >> ConstantR0<float>(&builder, 2);
  auto computation_or = builder.Build();
  ASSERT_FALSE(computation_or.ok());
  const auto& message = computation_or.status().error_message();
  EXPECT_THAT(message,
              HasSubstr("Argument to >> operator does not have an integral type"));
}
182
// Adding a scalar constant to an f32[3,5] parameter should broadcast the
// constant before the add.
TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) {
  XlaBuilder builder(TestName());
  XlaOp param = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {3, 5}), "x");
  Add(param, ConstantR0<float>(&builder, 1.0));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Parameter(), op::Broadcast(op::Constant())));
}
191
// Adding s32[2,4] to s32[2,4,6] with explicit broadcast_dimensions {0, 1}
// should broadcast the lower-rank operand and yield the higher-rank shape.
TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) {
  XlaBuilder builder(TestName());
  const Shape lhs_shape = ShapeUtil::MakeShape(S32, {2, 4, 6});
  const Shape rhs_shape = ShapeUtil::MakeShape(S32, {2, 4});
  XlaOp lhs = Parameter(&builder, 0, lhs_shape, "x");
  XlaOp rhs = Parameter(&builder, 1, rhs_shape, "y");
  XlaOp sum = Add(lhs, rhs, /*broadcast_dimensions=*/{0, 1});

  // The builder-inferred shape matches the rank-3 operand.
  TF_ASSERT_OK_AND_ASSIGN(auto sum_shape, builder.GetShape(sum));
  EXPECT_TRUE(ShapeUtil::Equal(sum_shape, lhs_shape));

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Parameter(0), op::Broadcast(op::Parameter(1))));
}
207
// Using the same parameter for both operands of an add should reference
// parameter 0 twice.
TEST_F(XlaBuilderTest, XPlusX) {
  XlaBuilder builder(TestName());
  const Shape rank4 = ShapeUtil::MakeShape(S32, {1, 3, 5, 7});
  XlaOp param = Parameter(&builder, 0, rank4, "x");
  Add(param, param);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Parameter(0), op::Parameter(0)));
}
216
// Adding operands of different rank with no broadcast_dimensions must fail
// shape inference.
TEST_F(XlaBuilderTest, ShapeInferenceError) {
  XlaBuilder builder(TestName());
  XlaOp lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x");
  XlaOp rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {2, 4}), "y");
  Add(lhs, rhs);
  auto module_or = BuildHloModule(&builder);
  ASSERT_FALSE(module_or.ok());
  EXPECT_THAT(module_or.status().error_message(),
              HasSubstr("Shapes must be equal rank"));
}
227
// Reshaping an f32[1] operand whose dimension 0 was made dynamic (via
// SetDimensionSize) into a scalar must build successfully.
TEST_F(XlaBuilderTest, DynamicDimensionReshapeToR0) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1}), "x");
  auto y = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "dyn_dim");
  auto dx = SetDimensionSize(x, y, 0);
  Reshape(dx, {});
  auto statusor = BuildHloModule(&b);
  // ASSERT_IS_OK reports the failing status; ASSERT_TRUE(ok()) would only
  // print "false".
  ASSERT_IS_OK(statusor.status());
}
237
// Declaring two parameters with the same parameter number in one builder
// must make the build fail.
TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "x");
  // Deliberately reuses parameter number 0.
  auto y = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "y");
  Add(x, y);
  auto statusor = BuildHloModule(&b);
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("parameter 0 already registered"));
}
251
// Calling the same sub-computation twice (once on parameters, once on
// constants) should yield two call instructions feeding an add.
TEST_F(XlaBuilderTest, Call) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  // Callee: adds its two scalar parameters.
  XlaBuilder callee_builder("the_only_to_apply");
  XlaOp callee_lhs = Parameter(&callee_builder, 0, f32_scalar, "p0");
  XlaOp callee_rhs = Parameter(&callee_builder, 1, f32_scalar, "p1");
  Add(callee_lhs, callee_rhs);
  TF_ASSERT_OK_AND_ASSIGN(auto callee, callee_builder.Build());

  XlaBuilder builder(TestName());
  XlaOp x = Parameter(&builder, 0, f32_scalar, "x");
  XlaOp y = Parameter(&builder, 1, f32_scalar, "y");
  XlaOp one = ConstantR0<float>(&builder, 1);
  XlaOp two = ConstantR0<float>(&builder, 2);
  Add(Call(&builder, callee, {x, y}), Call(&builder, callee, {one, two}));

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Call(op::Parameter(), op::Parameter()),
                      op::Call(op::Constant(), op::Constant())));
}
269
// A size-1 (degenerate) dimension in one operand of a same-rank binop should
// be removed by a reshape and then broadcast back to the full shape.
TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) {
  XlaBuilder builder(TestName());
  XlaOp lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x");
  XlaOp rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y");
  Add(lhs, rhs);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  // Expected:
  //
  //  x: f32[1,2,3]  y: f32[1,2,1]
  //      |               |
  //      |          reshape: f32[1,2]
  //      |               |
  //      |          broadcast: f32[1,2,3]
  //       \             /
  //             add
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Parameter(0),
                      op::Broadcast(op::Reshape(op::Parameter(1)))));
}
290
// A binop that needs both an in-dim broadcast (rank 2 -> rank 3) and a
// degenerate broadcast (size-1 dimension) on the other operand.
TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) {
  XlaBuilder builder(TestName());
  XlaOp lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x");
  XlaOp rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(F32, {2, 1, 4}), "y");
  Add(lhs, rhs, /*broadcast_dimensions=*/{0, 1});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));

  // The in-dim broadcast of x happens first. The degenerate dimension of y is
  // then handled as a reshape followed by a broadcast.
  //
  // Expected:
  //
  //  x: f32[2,3]            y: f32[2,1,4]
  //      |                        |
  //  broadcast: f32[2,3,4]  reshape: f32[2,4]
  //      |                        |
  //      |                  broadcast: f32[2,3,4]
  //       \                      /
  //                 add
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Broadcast(op::Parameter(0)),
                      op::Broadcast(op::Reshape(op::Parameter(1)))));
}
315
// BroadcastInDim of f32[2,3] into [2,4,3] (mapping to output dims 0 and 2)
// should produce a broadcast root.
TEST_F(XlaBuilderTest, BroadcastInDim) {
  XlaBuilder builder(TestName());
  XlaOp operand = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x");
  BroadcastInDim(operand, {2, 4, 3}, /*broadcast_dimensions=*/{0, 2});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Broadcast());
}
325
// BroadcastInDim that expands a size-1 dimension (f32[2,1,4] -> [2,3,4])
// should lower to broadcast(reshape(broadcast)).
TEST_F(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) {
  XlaBuilder builder(TestName());
  XlaOp operand =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x");
  BroadcastInDim(operand, {2, 3, 4}, /*broadcast_dimensions=*/{0, 1, 2});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Broadcast(op::Reshape(op::Broadcast())));
}
335
// A negative output dimension size in BroadcastInDim must be rejected.
TEST_F(XlaBuilderTest, BroadcastInDimWithNegativeSize) {
  XlaBuilder builder(TestName());
  XlaOp operand =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x");
  BroadcastInDim(operand, {-3, 3, 4}, /*broadcast_dimensions=*/{0, 1, 2});
  auto module_or = BuildHloModule(&builder);
  ASSERT_FALSE(module_or.ok());
  EXPECT_THAT(module_or.status().error_message(), HasSubstr("invalid shape"));
}
345
// Using an XlaOp created by one builder as an operand in another builder must
// fail with a message naming both builders.
TEST_F(XlaBuilderTest, OperandFromWrongBuilder) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  XlaBuilder foreign_builder("b1");
  XlaOp foreign_param = Parameter(&foreign_builder, 0, f32_scalar, "p0");

  XlaBuilder main_builder("main");
  XlaOp local_param = Parameter(&main_builder, 0, f32_scalar, "p");
  Add(local_param, foreign_param);

  auto computation_or = main_builder.Build();
  ASSERT_FALSE(computation_or.ok());
  EXPECT_THAT(
      computation_or.status().error_message(),
      HasSubstr(
          "built by builder 'b1', but is trying to use it in builder 'main'"));
}
359
// A reshape without an explicit dimension order should be a plain reshape of
// the parameter, with no transpose.
TEST_F(XlaBuilderTest, ReshapeDefaultOrder) {
  XlaBuilder builder(TestName());
  XlaOp operand =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x");
  Reshape(operand, /*new_sizes=*/{6, 35});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Reshape(op::Parameter()));
}
368
// A reshape with a non-identity dimension order should transpose the operand
// before reshaping it.
TEST_F(XlaBuilderTest, ReshapeHasTranspose) {
  XlaBuilder builder(TestName());
  XlaOp operand =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {2, 3, 5, 7}), "x");
  Reshape(operand, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Reshape(op::Transpose(op::Parameter())));
}
377
// Transpose of a rank-2 parameter should emit a single transpose instruction.
TEST_F(XlaBuilderTest, Transpose) {
  XlaBuilder builder(TestName());
  XlaOp operand = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
  Transpose(operand, /*permutation=*/{1, 0});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Transpose(op::Parameter()));
}
386
// All-gather of f32[4] along dimension 0 with 4 shards should yield f32[16].
TEST_F(XlaBuilderTest, AllGatherR1) {
  XlaBuilder builder(TestName());
  XlaOp operand = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {4}), "x");
  AllGather(operand, /*all_gather_dimension=*/0, /*shard_count=*/4);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const HloInstruction* result =
      hlo_module->entry_computation()->root_instruction();

  EXPECT_EQ(result->opcode(), HloOpcode::kAllGather);
  const Shape expected_shape = ShapeUtil::MakeShape(F32, {16});
  EXPECT_TRUE(ShapeUtil::Equal(result->shape(), expected_shape));
}
397
// All-gather of f32[4,16] along dimension 1 with 4 shards should yield
// f32[4,64].
TEST_F(XlaBuilderTest, AllGatherR2) {
  XlaBuilder builder(TestName());
  XlaOp operand =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
  AllGather(operand, /*all_gather_dimension=*/1, /*shard_count=*/4);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const HloInstruction* result =
      hlo_module->entry_computation()->root_instruction();

  EXPECT_EQ(result->opcode(), HloOpcode::kAllGather);
  const Shape expected_shape = ShapeUtil::MakeShape(F32, {4, 64});
  EXPECT_TRUE(ShapeUtil::Equal(result->shape(), expected_shape));
}
409
// Reduce-scatter of f32[4,16] along dimension 1 with 2 shards (one replica
// group {0, 1}) should yield f32[4,8].
TEST_F(XlaBuilderTest, ReduceScatter) {
  XlaBuilder builder(TestName());

  // Scalar f32 addition, used as the reduction computation.
  XlaComputation reducer;
  {
    std::unique_ptr<XlaBuilder> reducer_builder =
        builder.CreateSubBuilder("add");
    const Shape f32_scalar = ShapeUtil::MakeScalarShape(F32);
    XlaOp lhs = Parameter(reducer_builder.get(), 0, f32_scalar, "x");
    XlaOp rhs = Parameter(reducer_builder.get(), 1, f32_scalar, "y");
    Add(lhs, rhs);
    TF_ASSERT_OK_AND_ASSIGN(reducer, reducer_builder->Build());
  }

  XlaOp operand =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
  ReplicaGroup replica_group;
  for (int replica_id : {0, 1}) {
    replica_group.add_replica_ids(replica_id);
  }
  ReduceScatter(operand, reducer, /*scatter_dimension=*/1, /*shard_count=*/2,
                /*replica_groups=*/{replica_group});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const HloInstruction* result =
      hlo_module->entry_computation()->root_instruction();

  EXPECT_EQ(result->opcode(), HloOpcode::kReduceScatter);
  const Shape expected_shape = ShapeUtil::MakeShape(F32, {4, 8});
  EXPECT_TRUE(ShapeUtil::Equal(result->shape(), expected_shape));
}
435
// AllToAll with split_dimension != concat_dimension: splitting dimension 1
// (size 16) into 2 parts and concatenating along dimension 0 turns f32[4,16]
// into f32[8,8].
TEST_F(XlaBuilderTest, AllToAll) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
  AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0,
           /*split_count=*/2);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  auto root = module->entry_computation()->root_instruction();

  // The all-to-all is not the root here: the root is a reshape, and the
  // all-to-all sits three operand(0) hops below it.
  // NOTE(review): the two intermediate ops are presumably transpose and reshape
  // (all-to-all -> reshape -> transpose -> reshape); confirm against
  // XlaBuilder's AllToAll lowering before tightening this check.
  EXPECT_EQ(root->opcode(), HloOpcode::kReshape);
  EXPECT_EQ(root->operand(0)->operand(0)->operand(0)->opcode(),
            HloOpcode::kAllToAll);
  EXPECT_TRUE(
      ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8})));
}
451
452 // Test the special case where split_dimension is the same as concat_dimension.
TEST_F(XlaBuilderTest, AllToAllSpecial) {
  XlaBuilder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {4, 16, 8});
  XlaOp operand = Parameter(&builder, 0, operand_shape, "x");
  AllToAll(operand, /*split_dimension=*/0, /*concat_dimension=*/0,
           /*split_count=*/2);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const HloInstruction* result =
      hlo_module->entry_computation()->root_instruction();

  // When split and concat dimensions coincide, the root is the all-to-all
  // itself and the operand shape is preserved.
  EXPECT_EQ(result->opcode(), HloOpcode::kAllToAll);
  EXPECT_TRUE(ShapeUtil::Equal(result->shape(),
                               ShapeUtil::MakeShape(F32, {4, 16, 8})));
}
466
// AllToAllTuple over two f32[2,4] operands with an explicit ascending layout
// should emit one tuple-shaped all-to-all that carries that layout and the
// replica group {0, 1}.
TEST_F(XlaBuilderTest, AllToAllTuple) {
  XlaBuilder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2, 4});
  XlaOp first = Parameter(&builder, 0, operand_shape, "p0");
  XlaOp second = Parameter(&builder, 1, operand_shape, "p1");
  ReplicaGroup replica_group;
  for (int replica_id : {0, 1}) {
    replica_group.add_replica_ids(replica_id);
  }

  AllToAllTuple({first, second}, {replica_group},
                LayoutUtil::MakeAscendingLayout(2));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const HloInstruction* result =
      hlo_module->entry_computation()->root_instruction();

  EXPECT_EQ(result->opcode(), HloOpcode::kAllToAll);
  const Shape element_shape =
      ShapeUtil::MakeShapeWithLayout(F32, /* dimensions= */ {2, 4},
                                     /* minor_to_major= */ {0, 1});
  const Shape expected_tuple_shape =
      ShapeUtil::MakeTupleShape({element_shape, element_shape});
  EXPECT_THAT(result, op::ShapeWithLayout(expected_tuple_shape));
  EXPECT_THAT(result, op::ReplicaGroups({{0, 1}}));
}
488
// CollectivePermute with (source, target) pairs 0->1, 1->2, 2->3 should emit a
// collective-permute root.
TEST_F(XlaBuilderTest, CollectivePermute) {
  XlaBuilder builder(TestName());
  XlaOp operand = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
  CollectivePermute(operand, {{0, 1}, {1, 2}, {2, 3}});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_EQ(hlo_module->entry_computation()->root_instruction()->opcode(),
            HloOpcode::kCollectivePermute);
}
497
// Querying the size of a dynamic dimension (dimension 1 is marked dynamic)
// should emit a get-dimension-size instruction.
TEST_F(XlaBuilderTest, GetDimensionSize) {
  XlaBuilder builder(TestName());
  const Shape operand_shape =
      ShapeUtil::MakeShape(F32, {5, 7}, {false, true});
  XlaOp operand = Parameter(&builder, 0, operand_shape, "x");
  GetDimensionSize(operand, 1);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_EQ(hlo_module->entry_computation()->root_instruction()->opcode(),
            HloOpcode::kGetDimensionSize);
}
507
// Querying the size of a static dimension (dimension 0 is not dynamic) should
// fold to a constant instead of a get-dimension-size instruction.
TEST_F(XlaBuilderTest, GetDimensionSizeConstant) {
  XlaBuilder b(TestName());
  auto x =
      Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}, {false, true}), "x");
  // Getting the dimension size of a static dimension gives us a constant.
  GetDimensionSize(x, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  auto root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kConstant);
}
518
// An error recorded through ReportError must make Build() fail with that
// error's message.
TEST_F(XlaBuilderTest, ReportError) {
  XlaBuilder builder(TestName());
  XlaOp operand = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
  XlaOp failed = builder.ReportError(InvalidArgument("a test error"));
  Add(failed, operand);
  auto computation_or = builder.Build();
  ASSERT_FALSE(computation_or.ok());
  EXPECT_THAT(computation_or.status().error_message(),
              HasSubstr("a test error"));
}
527
// ReportErrorOrReturn on an OK StatusOr should pass the wrapped op through
// unchanged.
TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) {
  XlaBuilder builder(TestName());
  StatusOr<XlaOp> constant_or(ConstantR0<float>(&builder, 1.0));
  Add(builder.ReportErrorOrReturn(constant_or),
      ConstantR0<float>(&builder, 2.0));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Add(op::Constant(), op::Constant()));
}
536
// ReportErrorOrReturn on an error StatusOr should record the error so that
// Build() fails with it.
TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) {
  XlaBuilder builder(TestName());
  StatusOr<XlaOp> failed_or(InvalidArgument("a test error"));
  Add(builder.ReportErrorOrReturn(failed_or), ConstantR0<float>(&builder, 2.0));
  auto computation_or = builder.Build();
  ASSERT_FALSE(computation_or.ok());
  EXPECT_THAT(computation_or.status().error_message(),
              HasSubstr("a test error"));
}
545
// Passing an explicit root to Build() should override the default root (the
// later add).
TEST_F(XlaBuilderTest, BuildWithSpecificRoot) {
  XlaBuilder builder(TestName());
  XlaOp chosen_root = ConstantR0<float>(&builder, 1.0);
  Add(chosen_root, ConstantR0<float>(&builder, 2.0));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          BuildHloModule(&builder, /*root=*/chosen_root));
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
              op::Constant());
}
554
// Choosing a parameter as the explicit root must still keep every entry
// parameter, along with the other instructions added to the builder.
TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) {
  XlaBuilder builder(TestName());
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {42, 123});
  XlaOp x = Parameter(&builder, 0, matrix_shape, "x");
  XlaOp y = Parameter(&builder, 1, matrix_shape, "y");
  XlaOp z = Parameter(&builder, 2, matrix_shape, "z");
  Add(x, Sub(y, z));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          BuildHloModule(&builder, /*root=*/x));
  const HloComputation* entry = hlo_module->entry_computation();

  EXPECT_THAT(entry->root_instruction(), op::Parameter());
  EXPECT_EQ(entry->num_parameters(), 3);
  // Three parameters plus the subtract and the add.
  EXPECT_EQ(entry->instruction_count(), 5);
}
570
// Build() must reject a root op that belongs to a different builder.
TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) {
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {42, 123});
  XlaBuilder builder(TestName());
  XlaBuilder foreign_builder(TestName());

  Parameter(&builder, 0, matrix_shape, "param");
  XlaOp foreign_param =
      Parameter(&foreign_builder, 0, matrix_shape, "other_param");

  Status build_status = builder.Build(foreign_param).status();
  ASSERT_IS_NOT_OK(build_status);
  EXPECT_THAT(build_status.error_message(),
              HasSubstr("root operation is not in this computation"));
}
585
// Building the same computation twice must produce byte-identical serialized
// protos.
TEST_F(XlaBuilderTest, ProtoMatches) {
  std::vector<XlaComputation> computations;
  const int n = 2;
  computations.reserve(n);
  for (int i = 0; i < n; ++i) {
    XlaBuilder b_call("the_only_to_apply");
    auto p0 = Parameter(&b_call, 0, ShapeUtil::MakeShape(F32, {}), "p0");
    auto p1 = Parameter(&b_call, 1, ShapeUtil::MakeShape(F32, {}), "p1");
    Add(p0, Add(p1, p0));
    TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build());
    XlaBuilder b(TestName());
    auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
    auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
    auto one = ConstantR0<float>(&b, 1);
    auto two = ConstantR0<float>(&b, 2);
    Add(Call(&b, call, {x, y}), Call(&b, call, {one, two}));
    // Assert rather than ValueOrDie() so a build failure is reported as a
    // test failure with its status instead of aborting the test binary.
    TF_ASSERT_OK_AND_ASSIGN(XlaComputation computation, b.Build());
    computations.push_back(std::move(computation));
  }
  auto c0_string = computations[0].proto().SerializeAsString();
  auto c1_string = computations[1].proto().SerializeAsString();
  EXPECT_EQ(c0_string, c1_string);
}
608
// Binding parameter 1 as the dynamic size of dimension 0 of tuple element {1}
// of parameter 0 should mark that dimension dynamic in the built module.
TEST_F(XlaBuilderTest, DynamicParameter) {
  XlaBuilder builder(TestName());
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5}), ShapeUtil::MakeShape(F32, {6}, {true})});
  XlaOp tuple_param = Parameter(&builder, 0, tuple_shape, "p0");
  Parameter(&builder, 1, ShapeUtil::MakeShape(U32, {}), "p1");
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/1,
                                         /*dynamic_size_param_index=*/{},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{1},
                                         /*target_dim_num=*/0));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          BuildHloModule(&builder, /*root=*/tuple_param));
  const HloInstruction* entry_param =
      hlo_module->entry_computation()->parameter_instruction(0);
  const Shape& bound_element_shape = entry_param->shape().tuple_shapes(1);
  EXPECT_TRUE(bound_element_shape.is_dynamic_dimension(0));
}
627
// SetDimensionSize should make the targeted dimension dynamic in the result.
TEST_F(XlaBuilderTest, SetDimensionSize) {
  XlaBuilder builder(TestName());
  XlaOp data = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10}), "p0");
  XlaOp size = Parameter(&builder, 1, ShapeUtil::MakeShape(S32, {}), "p1");
  XlaOp resized = SetDimensionSize(data, size, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          BuildHloModule(&builder, /*root=*/resized));
  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(result_shape.is_dynamic_dimension(0));
}
639
// RemoveDynamicDimension should undo the dynamic marking that SetDimensionSize
// applied to dimension 0.
TEST_F(XlaBuilderTest, RemoveDynamicDimension) {
  XlaBuilder builder(TestName());
  XlaOp data = Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10}), "p0");
  XlaOp size = Parameter(&builder, 1, ShapeUtil::MakeShape(S32, {}), "p1");
  XlaOp dynamic_data = SetDimensionSize(data, size, 0);
  XlaOp static_data = RemoveDynamicDimension(dynamic_data, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          BuildHloModule(&builder, /*root=*/static_data));
  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_FALSE(result_shape.is_dynamic_dimension(0));
}
653
// Making both dimensions of a rank-2 operand dynamic and then removing both
// should leave a fully static shape.
TEST_F(XlaBuilderTest, RemoveDynamicDimensionMultiDims) {
  XlaBuilder builder(TestName());
  XlaOp data =
      Parameter(&builder, 0, ShapeUtil::MakeShape(F32, {10, 10}), "p0");
  XlaOp size = Parameter(&builder, 1, ShapeUtil::MakeShape(S32, {}), "p1");

  XlaOp current = data;
  for (int64_t dim : {0, 1}) {
    current = SetDimensionSize(current, size, dim);
  }
  for (int64_t dim : {0, 1}) {
    current = RemoveDynamicDimension(current, dim);
  }

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          BuildHloModule(&builder, /*root=*/current));
  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_FALSE(result_shape.is_dynamic_dimension(0));
  EXPECT_FALSE(result_shape.is_dynamic_dimension(1));
}
670
// A unary op (Neg) on an operand with a dynamic dimension should keep that
// dimension dynamic.
TEST_F(XlaBuilderTest, DynamicUnary) {
  XlaBuilder builder(TestName());
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})});
  XlaOp tuple_param = Parameter(&builder, 0, tuple_shape, "p0");
  // Element {1} of the tuple supplies the size of dimension 0 of element {0}.
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  Neg(GetTupleElement(tuple_param, 0));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(result_shape.is_dynamic_dimension(0));
}
688
// Adding two operands whose dimension 0 is dynamic should keep dimension 0
// dynamic.
TEST_F(XlaBuilderTest, DynamicBinary) {
  XlaBuilder builder(TestName());
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5}, {true}),
       ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})});
  XlaOp tuple_param = Parameter(&builder, 0, tuple_shape, "p0");
  // Element {2} supplies the size of dimension 0 of elements {0} and {1}.
  for (int64_t target_element : {0, 1}) {
    ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                           /*dynamic_size_param_index=*/{2},
                                           /*target_param_num=*/0,
                                           /*target_param_index=*/{target_element},
                                           /*target_dim_num=*/0));
  }
  XlaOp lhs = GetTupleElement(tuple_param, 0);
  XlaOp rhs = GetTupleElement(tuple_param, 1);
  Add(lhs, rhs);
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(result_shape.is_dynamic_dimension(0));
}
713
// Adding a dynamic rank-1 operand to a rank-2 operand (broadcasting along
// dimension 0) should leave only dimension 0 of the result dynamic.
TEST_F(XlaBuilderTest, DynamicBinaryHasBroadcast) {
  XlaBuilder builder(TestName());
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}),
       ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})});
  XlaOp tuple_param = Parameter(&builder, 0, tuple_shape, "p0");
  // Element {2} supplies the size of dimension 0 of elements {0} and {1}.
  for (int64_t target_element : {0, 1}) {
    ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                           /*dynamic_size_param_index=*/{2},
                                           /*target_param_num=*/0,
                                           /*target_param_index=*/{target_element},
                                           /*target_dim_num=*/0));
  }
  XlaOp matrix = GetTupleElement(tuple_param, 0);
  XlaOp vector = GetTupleElement(tuple_param, 1);
  Add(matrix, vector, {0});
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, BuildHloModule(&builder));
  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
      << result_shape;
}
739
TEST_F(XlaBuilderTest, DynamicBroadcast) {
  // BroadcastInDim of f32[<=5, 4] into f32[3, 5, 4], mapping operand dims to
  // output dims 1 and 2. The dynamic dimension must move to output index 1.
  XlaBuilder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {5, 4}, {true, false});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {operand_shape, ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  ASSERT_IS_OK(builder.SetDynamicBinding(
      /*dynamic_size_param_num=*/0, /*dynamic_size_param_index=*/{1},
      /*target_param_num=*/0, /*target_param_index=*/{0},
      /*target_dim_num=*/0));
  XlaOp operand = GetTupleElement(param, 0);
  BroadcastInDim(operand, /*out_dim_size=*/{3, 5, 4},
                 /*broadcast_dimensions=*/{1, 2});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& root_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(
      ContainersEqual(root_shape.dynamic_dimensions(), {false, true, false}))
      << root_shape;
}
761
TEST_F(XlaBuilderTest, DynamicBinaryHasDegenerateBroadcast) {
  // Adds f32[<=10] to f32[1, 15] with broadcast_dimensions={0}. The result is
  // expected to be f32[<=10, 15], with only dimension 0 dynamic.
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {10}, {true}),
       ShapeUtil::MakeShape(F32, {1, 15}), ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  // The dynamic size must come from the u32 scalar at tuple index 2. Index 1
  // is the f32[1, 15] operand, so binding to it would take the "size" from a
  // float array.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  auto gte0 = GetTupleElement(p0, 0);
  auto gte1 = GetTupleElement(p0, 1);
  Add(gte0, gte1, /*broadcast_dimensions=*/{0});  // f32[<=10, 15]
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
      << result_shape;
}
782
TEST_F(XlaBuilderTest, DynamicSelectOnlyPredDynamic) {
  // Select where only the pred[<=10] is dynamic and both branches are static
  // f32[10]. The result is expected to inherit the dynamic dimension 0.
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(PRED, {10}, {true}),
       ShapeUtil::MakeShape(F32, {10}), ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  // The dynamic size must come from the u32 scalar at tuple index 2. Index 1
  // is the f32[10] branch operand, not a size.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  auto gte0 = GetTupleElement(p0, 0);
  auto gte1 = GetTupleElement(p0, 1);

  Select(gte0, gte1, gte1);

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true}))
      << result_shape;
}
805
TEST_F(XlaBuilderTest, SelectIntoConditional) {
  // A Select with a scalar PRED and tuple-shaped operands is expected to be
  // built as a Conditional whose two branches each return their parameter.
  XlaBuilder builder(TestName());
  const Shape pred_shape = ShapeUtil::MakeShape(PRED, {});
  const Shape branch_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})});
  XlaOp pred = Parameter(&builder, 0, pred_shape, "p0");
  XlaOp on_true = Parameter(&builder, 1, branch_shape, "p1");
  XlaOp on_false = Parameter(&builder, 2, branch_shape, "p2");

  Select(pred, on_true, on_false);

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          BuildHloModule(&builder));
  HloInstruction* root = module->entry_computation()->root_instruction();
  // ASSERT, not EXPECT: branch_computation() below is only meaningful on a
  // Conditional, so stop here instead of crashing if the root is not one.
  ASSERT_THAT(root, op::Conditional(op::Parameter(0), op::Parameter(1),
                                    op::Parameter(2)));
  for (int branch = 0; branch < 2; ++branch) {
    EXPECT_THAT(root->branch_computation(branch)->root_instruction(),
                op::Parameter(0))
        << "branch " << branch;
  }
}
833
TEST_F(XlaBuilderTest, DynamicPad) {
  // Pads f32[<=5, 4] with zero low/high/interior padding on both dimensions.
  // Dimension 0 must remain dynamic in the Pad result.
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}),
       ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  XlaOp padding_value = ConstantR0<float>(&builder, -1);
  ASSERT_IS_OK(builder.SetDynamicBinding(
      /*dynamic_size_param_num=*/0, /*dynamic_size_param_index=*/{1},
      /*target_param_num=*/0, /*target_param_index=*/{0},
      /*target_dim_num=*/0));
  XlaOp operand = GetTupleElement(param, 0);
  PaddingConfig config;
  for (int dim = 0; dim < 2; ++dim) {
    auto* dim_config = config.add_dimensions();
    dim_config->set_edge_padding_low(0);
    dim_config->set_edge_padding_high(0);
    dim_config->set_interior_padding(0);
  }
  Pad(operand, padding_value, config);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& root_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(root_shape.dynamic_dimensions(), {true, false}))
      << root_shape;
}
861
// Convolves an f32[<=1, 2, 2, 128] input with an f32[2, 2, <=128, 8] filter
// and checks that only the batch dimension of the result remains dynamic.
TEST_F(XlaBuilderTest, DynamicConvolution) {
  XlaBuilder b(TestName());
  // Tuple: input, filter, then one u32 size scalar per dynamic dimension.
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {1, 2, 2, 128}, {true, false, false, false}),
       ShapeUtil::MakeShape(F32, {2, 2, 128, 8}, {false, false, true, false}),
       ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  // Input dim 0 (batch) is sized by tuple element 2.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  // Filter dim 2 (kernel input feature) is sized by tuple element 3.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{3},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{1},
                                   /*target_dim_num=*/2));
  auto input = GetTupleElement(p0, 0);
  auto filter = GetTupleElement(p0, 1);
  // Input/output: batch 0, spatial {1, 2}, feature 3.
  // Kernel: spatial {0, 1}, input feature 2, output feature 3.
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.set_input_feature_dimension(3);
  dnums.set_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(3);
  ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
                            /*feature_group_count=*/1);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  // The dynamic kernel input-feature dimension is contracted and does not
  // appear in the output; only the batch dimension stays dynamic.
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(),
                              {true, false, false, false}))
      << result_shape;
}
903
// Batched DotGeneral of f32[<=2, <=3, 4] x f32[<=2, 4, 5]. Batch dim 0 and
// the lhs non-contracting dim 1 are dynamic, so the result is expected to be
// f32[<=2, <=3, 5].
TEST_F(XlaBuilderTest, DynamicDot) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 3, 4}, {true, true, false}),
       ShapeUtil::MakeShape(F32, {2, 4, 5}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  // Tuple element 2 sizes the batch dim (0) of both lhs and rhs.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{1},
                                   /*target_dim_num=*/0));
  // Tuple element 3 sizes lhs dim 1.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{3},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/1));

  auto lhs = GetTupleElement(p0, 0);
  auto rhs = GetTupleElement(p0, 1);
  // Contract lhs dim 2 with rhs dim 1; dim 0 is the batch dimension on both.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);
  DotGeneral(lhs, rhs, dnums);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(
      ContainersEqual(result_shape.dynamic_dimensions(), {true, true, false}))
      << result_shape;
}
942
// Sum-reduces f32[5, <=4, 3] over dimension 0. The remaining dynamic
// dimension should become dimension 0 of the f32[<=4, 3] result.
TEST_F(XlaBuilderTest, DynamicReduce) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5, 4, 3}, {false, true, false}),
       ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  auto init = ConstantR0<float>(&b, 0);
  // Tuple element 1 sizes dim 1 of the operand.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{1},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/1));
  auto gte = GetTupleElement(p0, 0);
  // Scalar f32 add computation used as the reducer.
  XlaBuilder bsum(TestName());
  Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"),
      Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y"));
  TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build());
  Reduce(gte, init, sum, {0});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
      << result_shape;
}
967
// Sum ReduceWindow over f32[<=2, 4, 8] with window {1, 2, 4}, unit strides and
// valid padding. The window size on dimension 0 is 1, and the test expects
// dimension 0 of the result to stay dynamic.
TEST_F(XlaBuilderTest, DynamicReduceWindow) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  auto init = ConstantR0<float>(&b, 0.f);
  // Tuple element 1 sizes dim 0 of the operand.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{1},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  auto gte = GetTupleElement(p0, 0);
  // Scalar f32 add computation used as the window reducer.
  XlaBuilder bsum(TestName());
  Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"),
      Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y"));
  TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build());
  ReduceWindow(gte, init, sum, /*window_dimensions=*/{1, 2, 4},
               /*window_strides=*/{1, 1, 1}, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  VLOG(2) << module->entry_computation()->root_instruction()->ToString()
          << "\n";
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(
      ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false}))
      << result_shape;
}
996
// Variadic (two-operand) ReduceWindow over two f32[<=2, 4, 8] inputs with a
// tuple-returning pairwise-add reducer. Both result tuple elements are
// expected to keep dimension 0 dynamic.
TEST_F(XlaBuilderTest, VariadicDynamicReduceWindow) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  auto p1 = Parameter(&b, 1, tuple_param_shape, "p1");
  // Only parameter 0 gets a dynamic binding. NOTE(review): p1's dim 0 is
  // marked dynamic solely by its declared shape, with no size binding;
  // confirm this is intended.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{1},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  auto gte0 = GetTupleElement(p0, 0);
  auto gte1 = GetTupleElement(p1, 0);
  std::vector<XlaOp> input_operands = {gte0, gte1};
  // Reducer (x0, x1, y0, y1) -> (x0 + y0, x1 + y1).
  XlaBuilder bsum(TestName());
  auto p2 = Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x0");
  auto p3 = Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "x1");
  auto p4 = Parameter(&bsum, 2, ShapeUtil::MakeShape(F32, {}), "y0");
  auto p5 = Parameter(&bsum, 3, ShapeUtil::MakeShape(F32, {}), "y1");
  std::vector<XlaOp> output_operands = {Add(p2, p4), Add(p3, p5)};
  Tuple(&bsum, absl::MakeSpan(output_operands));
  TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build());
  auto init = ConstantR0<float>(&b, 0.f);
  ReduceWindow(input_operands, {init, init}, sum,
               /*window_dimensions=*/{1, 2, 4},
               /*window_strides=*/{1, 1, 1}, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  VLOG(2) << module->entry_computation()->root_instruction()->ToString()
          << "\n";
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.tuple_shapes(0).dynamic_dimensions(),
                              {true, false, false}))
      << result_shape.tuple_shapes(0);
  EXPECT_TRUE(ContainersEqual(result_shape.tuple_shapes(1).dynamic_dimensions(),
                              {true, false, false}))
      << result_shape.tuple_shapes(1);
}
1036
// SelectAndScatter over an f32[<=2, 4, 8] operand with an f32[<=2, 2, 2]
// source, window and strides {1, 2, 4}, and valid padding. The result takes
// the operand's shape, so dimension 0 is expected to stay dynamic.
TEST_F(XlaBuilderTest, DynamicSelectAndScatter) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
       ShapeUtil::MakeShape(F32, {2, 2, 2}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  auto init = ConstantR0<float>(&b, 0.f);
  // Scatter computation: scalar f32 add.
  XlaBuilder bsum(TestName());
  Add(Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x"),
      Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "y"));
  TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build());
  // Select computation: scalar f32 greater-or-equal.
  XlaBuilder bge(TestName());
  Ge(Parameter(&bge, 0, ShapeUtil::MakeShape(F32, {}), "x"),
     Parameter(&bge, 1, ShapeUtil::MakeShape(F32, {}), "y"));
  TF_ASSERT_OK_AND_ASSIGN(auto ge, bge.Build());

  // Tuple element 2 sizes dim 0 of both the operand and the source.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{1},
                                   /*target_dim_num=*/0));
  auto gte0 = GetTupleElement(p0, 0);
  auto source = GetTupleElement(p0, 1);
  SelectAndScatter(gte0, ge, {1, 2, 4}, {1, 2, 4}, Padding::kValid, source,
                   init, sum);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(
      ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false}))
      << result_shape;
}
1075
TEST_F(XlaBuilderTest, DynamicReshape) {
  // Reshapes f32[2, 3, <=4, <=5, 6] to [6, 4, 5, 2, 3]. The two dynamic input
  // dimensions are expected to land on output dimensions 1 and 2.
  XlaBuilder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(
      F32, {2, 3, 4, 5, 6}, {false, false, true, true, false});
  const Shape size_scalar = ShapeUtil::MakeShape(U32, {});
  const Shape param_shape =
      ShapeUtil::MakeTupleShape({operand_shape, size_scalar, size_scalar});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  // Tuple element 1 sizes operand dim 2; tuple element 2 sizes operand dim 3.
  ASSERT_IS_OK(builder.SetDynamicBinding(
      /*dynamic_size_param_num=*/0, /*dynamic_size_param_index=*/{1},
      /*target_param_num=*/0, /*target_param_index=*/{0},
      /*target_dim_num=*/2));
  ASSERT_IS_OK(builder.SetDynamicBinding(
      /*dynamic_size_param_num=*/0, /*dynamic_size_param_index=*/{2},
      /*target_param_num=*/0, /*target_param_index=*/{0},
      /*target_dim_num=*/3));
  XlaOp operand = GetTupleElement(param, 0);  // f32[2, 3, <=4, <=5, 6]
  Reshape(operand, /*new_sizes=*/{6, 4, 5, 2, 3});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& root_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(root_shape.is_dynamic_dimension(1));
  EXPECT_TRUE(root_shape.is_dynamic_dimension(2));
  EXPECT_TRUE(ContainersEqual(root_shape.dynamic_dimensions(),
                              {false, true, true, false, false}))
      << root_shape;
}
1104
TEST_F(XlaBuilderTest, DynamicSelect) {
  // Scalar-predicate Select between two f32[4, <=5, 6] operands, each sized by
  // its own u32 scalar. The result must be f32[4, <=5, 6].
  XlaBuilder builder(TestName());
  const Shape branch_shape =
      ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false});
  const Shape size_scalar = ShapeUtil::MakeShape(U32, {});
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {branch_shape, branch_shape, size_scalar, size_scalar});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  XlaOp pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "pred");
  // Operand 0 dim 1 <- tuple element 2; operand 1 dim 1 <- tuple element 3.
  ASSERT_IS_OK(builder.SetDynamicBinding(
      /*dynamic_size_param_num=*/0, /*dynamic_size_param_index=*/{2},
      /*target_param_num=*/0, /*target_param_index=*/{0},
      /*target_dim_num=*/1));
  ASSERT_IS_OK(builder.SetDynamicBinding(
      /*dynamic_size_param_num=*/0, /*dynamic_size_param_index=*/{3},
      /*target_param_num=*/0, /*target_param_index=*/{1},
      /*target_dim_num=*/1));
  XlaOp on_true = GetTupleElement(param, 0);
  XlaOp on_false = GetTupleElement(param, 1);
  Select(pred, on_true, on_false);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& root_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(root_shape.is_dynamic_dimension(1));
  EXPECT_FALSE(root_shape.is_dynamic_dimension(2));
  EXPECT_TRUE(
      ContainersEqual(root_shape.dynamic_dimensions(), {false, true, false}))
      << root_shape;
}
1135
TEST_F(XlaBuilderTest, DynamicSelectNotCompatible) {
  // Select between f32[4, <=5, 6] and f32[4, 5, <=6], whose dynamic dimensions
  // differ. Despite the test name, the builder is expected to accept this and
  // build successfully.
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}),
       ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, false, true}),
       ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  XlaOp pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "pred");
  ASSERT_IS_OK(builder.SetDynamicBinding(
      /*dynamic_size_param_num=*/0, /*dynamic_size_param_index=*/{2},
      /*target_param_num=*/0, /*target_param_index=*/{0},
      /*target_dim_num=*/1));
  ASSERT_IS_OK(builder.SetDynamicBinding(
      /*dynamic_size_param_num=*/0, /*dynamic_size_param_index=*/{3},
      /*target_param_num=*/0, /*target_param_index=*/{1},
      /*target_dim_num=*/2));
  XlaOp on_true = GetTupleElement(param, 0);   // f32[4,<=5,6]
  XlaOp on_false = GetTupleElement(param, 1);  // f32[4,5,<=6]
  Select(pred, on_true, on_false);
  ASSERT_IS_OK(BuildHloModule(&builder).status());
}
1160
TEST_F(XlaBuilderTest, DynamicTranspose) {
  // Transposing f32[<=3, 5] with permutation {1, 0} should move the dynamic
  // dimension from index 0 to index 1.
  XlaBuilder builder(TestName());
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {3, 5}, {true, false}),
       ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, param_shape, "p0");
  ASSERT_IS_OK(builder.SetDynamicBinding(
      /*dynamic_size_param_num=*/0, /*dynamic_size_param_index=*/{1},
      /*target_param_num=*/0, /*target_param_index=*/{0},
      /*target_dim_num=*/0));
  XlaOp operand = GetTupleElement(param, 0);
  Transpose(operand, /*permutation=*/{1, 0});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& root_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(root_shape.dynamic_dimensions(), {false, true}))
      << root_shape;
}
1180
TEST_F(XlaBuilderTest, DotWithPreferredElementType) {
  // A u8[2, 3] x u16[3, 2] matmul built with preferred_element_type=U32 must
  // produce a u32[2, 2] result.
  XlaBuilder builder(TestName());
  XlaOp lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(U8, {2, 3}), "p0");
  XlaOp rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(U16, {3, 2}), "p1");

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs, rhs, dot_dnums, /*precision_config=*/nullptr,
             /*preferred_element_type=*/U32);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape expected_shape = ShapeUtil::MakeShape(U32, {2, 2});
  const Shape& actual_shape =
      module->entry_computation()->root_instruction()->shape();
  ASSERT_TRUE(ShapeUtil::Equal(expected_shape, actual_shape));
}
1199
// Convolves an s16[1, 2, 2, 128] input with an s8[2, 2, 128, 8] filter using
// preferred_element_type=S32. Valid padding and unit strides give an
// s32[1, 1, 1, 8] result.
TEST_F(XlaBuilderTest, ConvolutionWithPreferredElementType) {
  XlaBuilder b(TestName());
  Shape p0_shape = ShapeUtil::MakeShape(S16, {1, 2, 2, 128});
  Shape p1_shape = ShapeUtil::MakeShape(S8, {2, 2, 128, 8});
  auto p0 = Parameter(&b, 0, p0_shape, "p0");
  auto p1 = Parameter(&b, 1, p1_shape, "p1");

  // Input/output: batch 0, spatial {1, 2}, feature 3.
  // Kernel: spatial {0, 1}, input feature 2, output feature 3.
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.set_input_feature_dimension(3);
  dnums.set_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(3);
  ConvWithGeneralDimensions(p0, p1, {1, 1}, Padding::kValid, dnums,
                            /*feature_group_count=*/1, /*batch_group_count=*/1,
                            /*precision_config=*/nullptr,
                            /*preferred_element_type=*/S32);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {1, 1, 1, 8}), result_shape));
}
1230
TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) {
  // AfterAll accepts only token operands. Passing an f32 constant alongside a
  // token must make Build() fail with a descriptive error.
  XlaBuilder builder(TestName());
  XlaOp token = CreateToken(&builder);
  XlaOp not_a_token = ConstantR0<float>(&builder, 1.0);
  AfterAll(&builder, {token, not_a_token});
  const Status build_status = builder.Build().status();
  ASSERT_IS_NOT_OK(build_status);
  EXPECT_THAT(build_status.error_message(),
              HasSubstr("All operands to AfterAll must be tokens"));
}
1239
TEST_F(XlaBuilderTest, CheckInputOutputAlias) {
  // Root is (p0 + p1, p0 - p1). Output {1} is aliased with parameter 0 and
  // output {0} with parameter 1, and the module's alias config must report
  // both pairings.
  XlaBuilder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {8, 4});
  XlaOp p0 = Parameter(&builder, 0, operand_shape, "p0");
  XlaOp p1 = Parameter(&builder, 1, operand_shape, "p1");
  XlaOp sum = Add(p0, p1);
  XlaOp difference = Sub(p0, p1);
  XlaOp root = Tuple(&builder, {sum, difference});

  builder.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0,
                     /*param_index=*/{});
  builder.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
                     /*param_index=*/{});

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder, root));

  const HloInputOutputAliasConfig& alias_config =
      module->input_output_alias_config();
  EXPECT_TRUE(alias_config.ParameterHasAlias(0, {}));
  EXPECT_TRUE(alias_config.ParameterHasAlias(1, {}));

  auto output_for_p0 = alias_config.GetAliasedOutput(0, {});
  ASSERT_TRUE(output_for_p0.has_value());
  EXPECT_EQ(*output_for_p0, ShapeIndex({1}));

  auto output_for_p1 = alias_config.GetAliasedOutput(1, {});
  ASSERT_TRUE(output_for_p1.has_value());
  EXPECT_EQ(*output_for_p1, ShapeIndex({0}));
}
1265
ExpectAttributesMatch(const FrontendAttributes & attr,const FrontendAttributes & ref)1266 void ExpectAttributesMatch(const FrontendAttributes& attr,
1267 const FrontendAttributes& ref) {
1268 EXPECT_EQ(ref.map_size(), attr.map_size());
1269 for (auto reference : ref.map()) {
1270 auto other = attr.map().find(reference.first);
1271 EXPECT_NE(other, attr.map().end());
1272 EXPECT_EQ(other->second, reference.second);
1273 }
1274 }
1275
ExpectInstructionsAttributesMatch(const HloModule & module,const std::vector<FrontendAttributes> & expected)1276 void ExpectInstructionsAttributesMatch(
1277 const HloModule& module, const std::vector<FrontendAttributes>& expected) {
1278 ASSERT_EQ(module.computation_count(), 1);
1279 auto expected_it = expected.begin();
1280 for (auto inst : module.entry_computation()->instructions()) {
1281 ASSERT_NE(expected_it, expected.end());
1282 ExpectAttributesMatch(inst->frontend_attributes(), *expected_it);
1283 expected_it++;
1284 }
1285 EXPECT_EQ(expected_it, expected.end());
1286 }
1287
// Builder-level frontend attributes apply to every op created while they are
// set: attributes are set, one constant is built, and then they are cleared.
// The expected list is matched against the module's instructions in order.
TEST_F(XlaBuilderTest, SimpleSetFrontendAttributes) {
  XlaBuilder b(TestName());
  FrontendAttributes attributes;

  ConstantR0(&b, 0);  // No attribute set

  (*attributes.mutable_map())["attr_a"] = "a";
  b.SetFrontendAttributes(attributes);
  ConstantR0(&b, 0);  // One attribute: { "attr_a": "a" }

  b.ClearFrontendAttributes();
  ConstantR0(&b, 0);  // No attribute set

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));

  std::vector<FrontendAttributes> expected{FrontendAttributes(), attributes,
                                           FrontendAttributes()};
  ExpectInstructionsAttributesMatch(*module, expected);
}
1307
// Each SetFrontendAttributes call replaces the builder's current attribute set
// rather than merging with it. For example, "attr_a" does not carry over once
// {attr_b} is set. The expected list mirrors the constants built, in order.
TEST_F(XlaBuilderTest, ComplexSetFrontendAttributes) {
  XlaBuilder b(TestName());

  ConstantR0(&b, 0);  // No attribute set.
  std::vector<FrontendAttributes> expected{FrontendAttributes()};

  {
    FrontendAttributes attributes;
    (*attributes.mutable_map())["attr_a"] = "a";
    b.SetFrontendAttributes(attributes);
    ConstantR0(&b, 0);  // One attribute: { "attr_a": "a" }
    expected.push_back(attributes);
  }

  {
    FrontendAttributes attributes;
    (*attributes.mutable_map())["attr_b"] = "b";
    b.SetFrontendAttributes(attributes);
    ConstantR0(&b, 0);  // One attribute: { "attr_b": "b" }
    expected.push_back(attributes);
  }

  {
    FrontendAttributes attributes;
    (*attributes.mutable_map())["attr_b"] = "b";
    (*attributes.mutable_map())["attr_c"] = "c";
    b.SetFrontendAttributes(attributes);
    ConstantR0(&b, 0);  // Two attributes: { "attr_b": "b", "attr_c": "c" }
    expected.push_back(attributes);
  }

  b.ClearFrontendAttributes();
  ConstantR0(&b, 0);  // No attribute set
  expected.push_back(FrontendAttributes());

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  ExpectInstructionsAttributesMatch(*module, expected);
}
1346
// SetInstructionFrontendAttribute adds or overrides an attribute on a single
// op on top of the builder-level attributes, without changing the attributes
// applied to ops built afterwards.
TEST_F(XlaBuilderTest, AddFrontendAttribute) {
  XlaBuilder b(TestName());

  ConstantR0(&b, 0);
  std::vector<FrontendAttributes> expected{FrontendAttributes()};

  // One attribute: { "attr_a": "a" }
  {
    FrontendAttributes attributes;
    (*attributes.mutable_map())["attr_a"] = "a";
    b.SetFrontendAttributes(attributes);
    ConstantR0(&b, 0);
    expected.push_back(attributes);
  }

  // Two attributes: { "attr_a": "a", "attr_c": "c" }
  // (builder-level "attr_a" plus per-instruction "attr_c").
  {
    auto op = ConstantR0(&b, 0);
    EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_c", "c"));

    FrontendAttributes attributes;
    (*attributes.mutable_map())["attr_a"] = "a";
    (*attributes.mutable_map())["attr_c"] = "c";
    expected.push_back(attributes);
  }

  // Override value of existing "attr_a" on this instruction only.
  // One attribute: { "attr_a": "a2" }
  {
    auto op = ConstantR0(&b, 0);
    EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_a", "a2"));
    FrontendAttributes attributes;
    (*attributes.mutable_map())["attr_a"] = "a2";
    expected.push_back(attributes);
  }

  // Check "attr_a" is back to its original value for newly built ops.
  // One attribute: { "attr_a": "a" }
  {
    auto op = ConstantR0(&b, 0);
    (void)op;
    FrontendAttributes attributes;
    (*attributes.mutable_map())["attr_a"] = "a";
    expected.push_back(attributes);
  }

  b.ClearFrontendAttributes();
  ConstantR0(&b, 0);  // No attribute set
  expected.push_back(FrontendAttributes());

  // Per-instruction attribute with no builder-level attributes set.
  // One attribute: { "attr_d": "d" }
  {
    auto op = ConstantR0(&b, 0);
    EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_d", "d"));
    FrontendAttributes attributes;
    (*attributes.mutable_map())["attr_d"] = "d";
    expected.push_back(attributes);
  }

  ConstantR0(&b, 0);  // No attribute set
  expected.push_back(FrontendAttributes());

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  ExpectInstructionsAttributesMatch(*module, expected);
}
1412
TEST_F(XlaBuilderTest, ComparisonType) {
  // Le on two s32 constants must produce a Compare with comparison type
  // kSigned.
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR0<int32_t>(&builder, 1);
  XlaOp rhs = ConstantR0<int32_t>(&builder, 2);
  (void)Le(lhs, rhs);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  auto* compare = module->entry_computation()->root_instruction();
  ASSERT_THAT(compare, op::Compare(op::Constant(), op::Constant()));
  EXPECT_EQ(Comparison::Type::kSigned,
            DynCast<HloCompareInstruction>(compare)->type());
}
1422
TEST_F(XlaBuilderTest, StableLookUpInstructionByHandle) {
  // The instruction proto pointer looked up for an op's handle must not change
  // after many more instructions are added to the builder.
  XlaBuilder builder(TestName());
  internal::XlaBuilderFriend builder_friend;
  XlaOp tracked =
      Le(ConstantR0<int32_t>(&builder, 1), ConstantR0<int32_t>(&builder, 2));
  HloInstructionProto* proto_before = builder_friend.GetInstruction(tracked);
  // Add 100 more comparisons, each with two fresh constants.
  for (int iteration = 0; iteration < 100; ++iteration) {
    (void)Le(ConstantR0<int32_t>(&builder, 1),
             ConstantR0<int32_t>(&builder, 2));
  }
  HloInstructionProto* proto_after = builder_friend.GetInstruction(tracked);
  EXPECT_EQ(proto_before, proto_after);
}
1436
TEST_F(XlaBuilderTest, ComplexAbsConstant) {
  // Value inference on Abs of a complex<float> constant should yield an F32
  // (real-valued) literal.
  XlaBuilder b(TestName());
  XlaOp out =
      Abs(ConstantR0<std::complex<float>>(&b, std::complex<float>{-1, -1}));
  ValueInference value_inference(&b);
  StatusOr<OptionalLiteral> analyzed =
      value_inference.AnalyzeConstant(out, kUpperBound);
  // ASSERT (not EXPECT): `analyzed` is dereferenced below, which is only valid
  // when the analysis succeeded.
  ASSERT_IS_OK(analyzed.status());
  auto value = analyzed->GetValue();
  // Guard the optional-like result instead of calling value() unchecked, so a
  // missing value fails the test rather than aborting it.
  ASSERT_TRUE(value.has_value());
  EXPECT_EQ(value->shape().element_type(), PrimitiveType::F32);
}
1448
1449 } // namespace
1450 } // namespace xla
1451