xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/xla_builder_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/xla_builder.h"
17 
18 #include <string>
19 
20 #include "tensorflow/compiler/xla/client/value_inference.h"
21 #include "tensorflow/compiler/xla/client/xla_computation.h"
22 #include "tensorflow/compiler/xla/debug_options_flags.h"
23 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
24 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
25 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
26 #include "tensorflow/compiler/xla/service/hlo_module.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 
34 namespace xla {
35 
36 namespace {
37 
38 namespace op = xla::testing::opcode_matchers;
39 
40 using ::testing::HasSubstr;
41 
42 // TODO(b/74197823): Move the tests to service/.
43 class XlaBuilderTest : public ::testing::Test {
44  protected:
BuildHloModule(XlaBuilder * b)45   StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b) {
46     TF_ASSIGN_OR_RETURN(XlaComputation computation,
47                         b->Build(/*remove_dynamic_dimensions=*/false));
48     const HloModuleProto& proto = computation.proto();
49     TF_ASSIGN_OR_RETURN(const auto& config,
50                         HloModule::CreateModuleConfigFromProto(
51                             proto, GetDebugOptionsFromFlags()));
52     return HloModule::CreateFromProto(proto, config);
53   }
54 
55   // Overload which explicitly specifies the root instruction.
BuildHloModule(XlaBuilder * b,XlaOp root)56   StatusOr<std::unique_ptr<HloModule>> BuildHloModule(XlaBuilder* b,
57                                                       XlaOp root) {
58     TF_ASSIGN_OR_RETURN(XlaComputation computation,
59                         b->Build(root, /*remove_dynamic_dimensions=*/false));
60     const HloModuleProto& proto = computation.proto();
61     TF_ASSIGN_OR_RETURN(const auto& config,
62                         HloModule::CreateModuleConfigFromProto(
63                             proto, GetDebugOptionsFromFlags()));
64     return HloModule::CreateFromProto(proto, config);
65   }
66 
67   // Returns the name of the test currently being run.
TestName() const68   std::string TestName() const {
69     return ::testing::UnitTest::GetInstance()->current_test_info()->name();
70   }
71 };
72 
// Adding two scalar constants should produce a single add of two constants.
TEST_F(XlaBuilderTest, OnePlusTwo) {
  XlaBuilder builder(TestName());
  XlaOp lhs = ConstantR0<float>(&builder, 1.0);
  XlaOp rhs = ConstantR0<float>(&builder, 2.0);
  Add(lhs, rhs);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Add(op::Constant(), op::Constant()));
}
80 
// Checks that each overloaded unary XlaOp operator lowers to the expected HLO
// opcode when applied to an S32 scalar constant.
TEST_F(XlaBuilderTest, UnaryOperatorsBuildExpectedHLO) {
  // The callable is named `build_op`, not `op`, so it does not shadow the
  // file-level `op` matcher namespace alias.
  auto test_unary_operator =
      [&](std::function<XlaOp(XlaOp)> build_op,
          ::testing::Matcher<const ::xla::HloInstruction*> matches_pattern) {
        XlaBuilder b(TestName());
        build_op(ConstantR0<int32_t>(&b, 1));
        TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
        auto root = module->entry_computation()->root_instruction();
        EXPECT_THAT(root, matches_pattern);
      };
  test_unary_operator([](XlaOp x) { return -x; }, op::Negate(op::Constant()));
  test_unary_operator([](XlaOp x) { return ~x; }, op::Not(op::Constant()));
}
94 
// Checks that each overloaded binary XlaOp operator lowers to the expected HLO
// opcode when applied to two scalar constants.
TEST_F(XlaBuilderTest, BinaryOperatorsBuildExpectedHLO) {
  using HloMatcher = ::testing::Matcher<const ::xla::HloInstruction*>;

  // Builds `build_op(1, 2)` over scalar constants whose element type is the
  // type of `type_tag` (its value is ignored), then matches the root.
  // - One generic helper covers both the signed and the unsigned cases.
  // - The callable is not named `op`, so it does not shadow the `op` matcher
  //   namespace alias.
  auto test_binary_operator = [&](auto type_tag,
                                  std::function<XlaOp(XlaOp, XlaOp)> build_op,
                                  HloMatcher matches_pattern) {
    using NativeT = decltype(type_tag);
    XlaBuilder b(TestName());
    build_op(ConstantR0<NativeT>(&b, 1), ConstantR0<NativeT>(&b, 2));
    TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
    auto root = module->entry_computation()->root_instruction();
    EXPECT_THAT(root, matches_pattern);
  };

  const int32_t s32_tag = 0;
  test_binary_operator(s32_tag, [](XlaOp x, XlaOp y) { return x + y; },
                       op::Add(op::Constant(), op::Constant()));
  test_binary_operator(s32_tag, [](XlaOp x, XlaOp y) { return x - y; },
                       op::Subtract(op::Constant(), op::Constant()));
  test_binary_operator(s32_tag, [](XlaOp x, XlaOp y) { return x * y; },
                       op::Multiply(op::Constant(), op::Constant()));
  test_binary_operator(s32_tag, [](XlaOp x, XlaOp y) { return x / y; },
                       op::Divide(op::Constant(), op::Constant()));

  test_binary_operator(s32_tag, [](XlaOp x, XlaOp y) { return x & y; },
                       op::And(op::Constant(), op::Constant()));
  test_binary_operator(s32_tag, [](XlaOp x, XlaOp y) { return x | y; },
                       op::Or(op::Constant(), op::Constant()));
  test_binary_operator(s32_tag, [](XlaOp x, XlaOp y) { return x ^ y; },
                       op::Xor(op::Constant(), op::Constant()));
  test_binary_operator(s32_tag, [](XlaOp x, XlaOp y) { return x << y; },
                       op::ShiftLeft(op::Constant(), op::Constant()));
  // `>>` on a signed element type lowers to an arithmetic shift...
  test_binary_operator(
      s32_tag, [](XlaOp x, XlaOp y) { return x >> y; },
      op::ShiftRightArithmetic(op::Constant(), op::Constant()));

  // ...while on an unsigned element type it lowers to a logical shift.
  const uint32_t u32_tag = 0;
  test_binary_operator(
      u32_tag, [](XlaOp x, XlaOp y) { return x >> y; },
      op::ShiftRightLogical(op::Constant(), op::Constant()));
}
140 
// A three-operand And must lower to two nested binary ands. Associativity is
// intentionally left unspecified, so either nesting is accepted.
TEST_F(XlaBuilderTest, VariadicAnd) {
  XlaBuilder b(TestName());
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  XlaOp p0 = Parameter(&b, 0, pred_scalar, "p0");
  XlaOp p1 = Parameter(&b, 1, pred_scalar, "p1");
  XlaOp p2 = Parameter(&b, 2, pred_scalar, "p2");
  And(p0, p1, p2);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));

  auto right_nested =
      op::And(op::Parameter(0), op::And(op::Parameter(1), op::Parameter(2)));
  auto left_nested =
      op::And(op::And(op::Parameter(0), op::Parameter(1)), op::Parameter(2));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              ::testing::AnyOf(right_nested, left_nested));
}
156 
// A three-operand Or must lower to two nested binary ors. Associativity is
// intentionally left unspecified, so either nesting is accepted.
TEST_F(XlaBuilderTest, VariadicOr) {
  XlaBuilder b(TestName());
  const Shape pred_scalar = ShapeUtil::MakeShape(PRED, {});
  XlaOp first = Parameter(&b, 0, pred_scalar, "p0");
  XlaOp second = Parameter(&b, 1, pred_scalar, "p1");
  XlaOp third = Parameter(&b, 2, pred_scalar, "p2");
  Or(first, second, third);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));

  auto right_nested =
      op::Or(op::Parameter(0), op::Or(op::Parameter(1), op::Parameter(2)));
  auto left_nested =
      op::Or(op::Or(op::Parameter(0), op::Parameter(1)), op::Parameter(2));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              ::testing::AnyOf(right_nested, left_nested));
}
172 
// Applying `>>` to F32 operands must make Build() fail with a type error.
TEST_F(XlaBuilderTest, ShiftRightOperatorOnNonIntegerProducesError) {
  XlaBuilder b(TestName());
  XlaOp lhs = ConstantR0<float>(&b, 1);
  XlaOp rhs = ConstantR0<float>(&b, 2);
  lhs >> rhs;
  auto build_result = b.Build();
  ASSERT_FALSE(build_result.ok());
  const auto& message = build_result.status().error_message();
  EXPECT_THAT(
      message,
      HasSubstr("Argument to >> operator does not have an integral type"));
}
182 
// Adding a scalar constant to an F32[3,5] parameter should broadcast the
// scalar before the add.
TEST_F(XlaBuilderTest, ParamPlusConstantHasScalarBroadcast) {
  XlaBuilder b(TestName());
  const Shape param_shape = ShapeUtil::MakeShape(F32, {3, 5});
  XlaOp x = Parameter(&b, 0, param_shape, "x");
  XlaOp one = ConstantR0<float>(&b, 1.0);
  Add(x, one);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Add(op::Parameter(), op::Broadcast(op::Constant())));
}
191 
// Adding S32[2,4] to S32[2,4,6] with broadcast_dimensions {0, 1}:
// - the result keeps the larger shape;
// - the smaller operand is broadcast.
TEST_F(XlaBuilderTest, ParamPlusParamHasBroadcast) {
  XlaBuilder b(TestName());
  const Shape big_shape = ShapeUtil::MakeShape(S32, {2, 4, 6});
  const Shape small_shape = ShapeUtil::MakeShape(S32, {2, 4});
  XlaOp big = Parameter(&b, 0, big_shape, "x");
  XlaOp small = Parameter(&b, 1, small_shape, "y");
  XlaOp sum = Add(big, small, /*broadcast_dimensions=*/{0, 1});

  TF_ASSERT_OK_AND_ASSIGN(Shape sum_shape, b.GetShape(sum));
  EXPECT_TRUE(ShapeUtil::Equal(sum_shape, big_shape));

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Add(op::Parameter(0), op::Broadcast(op::Parameter(1))));
}
207 
// Using the same parameter for both operands yields an add whose two operands
// are both parameter 0.
TEST_F(XlaBuilderTest, XPlusX) {
  XlaBuilder b(TestName());
  const Shape shape = ShapeUtil::MakeShape(S32, {1, 3, 5, 7});
  XlaOp param = Parameter(&b, 0, shape, "x");
  Add(param, param);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Add(op::Parameter(0), op::Parameter(0)));
}
216 
// Adding operands of different rank without broadcast dimensions must fail
// shape inference when the module is built.
TEST_F(XlaBuilderTest, ShapeInferenceError) {
  XlaBuilder b(TestName());
  XlaOp rank3 = Parameter(&b, 0, ShapeUtil::MakeShape(U32, {2, 4, 6}), "x");
  XlaOp rank2 = Parameter(&b, 1, ShapeUtil::MakeShape(U32, {2, 4}), "y");
  Add(rank3, rank2);
  auto module_or = BuildHloModule(&b);
  ASSERT_FALSE(module_or.ok());
  EXPECT_THAT(module_or.status().error_message(),
              HasSubstr("Shapes must be equal rank"));
}
227 
// Reshaping an F32[<=1] value whose only dimension is dynamic into a scalar
// must still build successfully.
TEST_F(XlaBuilderTest, DynamicDimensionReshapeToR0) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1}), "x");
  auto y = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "dyn_dim");
  auto dx = SetDimensionSize(x, y, 0);
  Reshape(dx, {});
  auto statusor = BuildHloModule(&b);
  // ASSERT_IS_OK compares against an OK status, so a failure reports the
  // actual error instead of a bare `false`.
  ASSERT_IS_OK(statusor.status());
}
237 
// Declaring two parameters with the same parameter number in one builder must
// make the build fail.
TEST_F(XlaBuilderTest, ParameterAlreadyRegistered) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "x");
  // Reuses parameter number 0 on purpose.
  auto y = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "y");
  Add(x, y);
  auto statusor = BuildHloModule(&b);
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("parameter 0 already registered"));
}
251 
// Calling the same sub-computation twice produces two call instructions. One
// call is fed by parameters and the other by constants.
TEST_F(XlaBuilderTest, Call) {
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Callee: adds its two scalar parameters.
  XlaBuilder callee_builder("the_only_to_apply");
  XlaOp a = Parameter(&callee_builder, 0, scalar_f32, "p0");
  XlaOp c = Parameter(&callee_builder, 1, scalar_f32, "p1");
  Add(a, c);
  TF_ASSERT_OK_AND_ASSIGN(auto callee, callee_builder.Build());

  XlaBuilder b(TestName());
  XlaOp x = Parameter(&b, 0, scalar_f32, "x");
  XlaOp y = Parameter(&b, 1, scalar_f32, "y");
  XlaOp one = ConstantR0<float>(&b, 1);
  XlaOp two = ConstantR0<float>(&b, 2);
  Add(Call(&b, callee, {x, y}), Call(&b, callee, {one, two}));

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Add(op::Call(op::Parameter(), op::Parameter()),
                      op::Call(op::Constant(), op::Constant())));
}
269 
// Adding F32[1,2,3] and F32[1,2,1] has a size-1 trailing dimension on the
// right-hand side. That dimension is reshaped away and then broadcast.
TEST_F(XlaBuilderTest, BinopHasDegenerateBroadcast) {
  XlaBuilder b(TestName());
  XlaOp lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 2, 3}), "x");
  XlaOp rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1, 2, 1}), "y");
  Add(lhs, rhs);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));

  // Expected:
  //
  //  x: f32[1,2,3]  y: f32[1,2,1]
  //      |               |
  //      |          reshape: f32[1,2]
  //      |               |
  //      |          broadcast: f32[1,2,3]
  //       \             /
  //            add
  auto expected = op::Add(op::Parameter(0),
                          op::Broadcast(op::Reshape(op::Parameter(1))));
  EXPECT_THAT(module->entry_computation()->root_instruction(), expected);
}
290 
// The binary operation needs both an in-dim broadcast and a degenerate
// broadcast. The in-dim broadcast is applied first; the degenerate one then
// becomes a reshape followed by a broadcast.
TEST_F(XlaBuilderTest, BinopHasInDimAndDegenerateBroadcast) {
  XlaBuilder b(TestName());
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 3});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 1, 4});
  XlaOp lhs = Parameter(&b, 0, lhs_shape, "x");
  XlaOp rhs = Parameter(&b, 1, rhs_shape, "y");
  Add(lhs, rhs, /*broadcast_dimensions=*/{0, 1});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));

  // Expected:
  //
  //  x: f32[2,3]            y: f32[2,1,4]
  //      |                        |
  //  broadcast: f32[2,3,4]  reshape: f32[2,4]
  //      |                        |
  //      |                  broadcast: f32[2,3,4]
  //       \                      /
  //                 add
  auto expected = op::Add(op::Broadcast(op::Parameter(0)),
                          op::Broadcast(op::Reshape(op::Parameter(1))));
  EXPECT_THAT(module->entry_computation()->root_instruction(), expected);
}
315 
// BroadcastInDim of F32[2,3] to [2,4,3] (mapping dims 0->0, 1->2) should be a
// single broadcast.
TEST_F(XlaBuilderTest, BroadcastInDim) {
  XlaBuilder b(TestName());
  XlaOp operand = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "x");
  BroadcastInDim(operand, {2, 4, 3}, /*broadcast_dimensions=*/{0, 2});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Broadcast());
}
325 
// Broadcasting F32[2,1,4] to [2,3,4] expands a size-1 dimension. The expected
// root chain is broadcast(reshape(broadcast)).
TEST_F(XlaBuilderTest, BroadcastInDimWithDegeneratedDim) {
  XlaBuilder b(TestName());
  XlaOp operand = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x");
  BroadcastInDim(operand, {2, 3, 4}, /*broadcast_dimensions=*/{0, 1, 2});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  auto expected = op::Broadcast(op::Reshape(op::Broadcast()));
  EXPECT_THAT(module->entry_computation()->root_instruction(), expected);
}
335 
// A negative output dimension size must be rejected as an invalid shape.
TEST_F(XlaBuilderTest, BroadcastInDimWithNegativeSize) {
  XlaBuilder b(TestName());
  XlaOp operand = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 1, 4}), "x");
  BroadcastInDim(operand, {-3, 3, 4}, /*broadcast_dimensions=*/{0, 1, 2});
  auto module_or = BuildHloModule(&b);
  ASSERT_FALSE(module_or.ok());
  EXPECT_THAT(module_or.status().error_message(), HasSubstr("invalid shape"));
}
345 
// Using an operand created by one builder inside another builder must make
// the second builder's Build() fail. The error names both builders.
TEST_F(XlaBuilderTest, OperandFromWrongBuilder) {
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});
  XlaBuilder foreign_builder("b1");
  XlaOp foreign_param = Parameter(&foreign_builder, 0, scalar_f32, "p0");

  XlaBuilder main_builder("main");
  XlaOp local_param = Parameter(&main_builder, 0, scalar_f32, "p");
  Add(local_param, foreign_param);

  auto build_result = main_builder.Build();
  ASSERT_FALSE(build_result.ok());
  EXPECT_THAT(
      build_result.status().error_message(),
      HasSubstr(
          "built by builder 'b1', but is trying to use it in builder 'main'"));
}
359 
// A reshape without an explicit dimension order is a plain reshape of the
// parameter.
TEST_F(XlaBuilderTest, ReshapeDefaultOrder) {
  XlaBuilder b(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5, 7});
  XlaOp input = Parameter(&b, 0, input_shape, "x");
  Reshape(input, /*new_sizes=*/{6, 35});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Reshape(op::Parameter()));
}
368 
// A reshape with an explicit dimension order is lowered to a transpose
// followed by a reshape.
TEST_F(XlaBuilderTest, ReshapeHasTranspose) {
  XlaBuilder b(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(F32, {2, 3, 5, 7});
  XlaOp input = Parameter(&b, 0, input_shape, "x");
  Reshape(input, /*dimensions=*/{3, 2, 1, 0}, /*new_sizes=*/{6, 35});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Reshape(op::Transpose(op::Parameter())));
}
377 
// Transposing a 2-D parameter yields a single transpose instruction.
TEST_F(XlaBuilderTest, Transpose) {
  XlaBuilder b(TestName());
  XlaOp matrix = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
  Transpose(matrix, /*permutation=*/{1, 0});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Transpose(op::Parameter()));
}
386 
// All-gathering F32[4] along dim 0 across 4 shards yields F32[16].
TEST_F(XlaBuilderTest, AllGatherR1) {
  XlaBuilder b(TestName());
  XlaOp shard = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x");
  AllGather(shard, /*all_gather_dimension=*/0, /*shard_count=*/4);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const auto* root = module->entry_computation()->root_instruction();

  const Shape expected_shape = ShapeUtil::MakeShape(F32, {16});
  EXPECT_EQ(root->opcode(), HloOpcode::kAllGather);
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), expected_shape));
}
397 
// All-gathering F32[4,16] along dim 1 across 4 shards yields F32[4,64].
TEST_F(XlaBuilderTest, AllGatherR2) {
  XlaBuilder b(TestName());
  XlaOp shard = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
  AllGather(shard, /*all_gather_dimension=*/1, /*shard_count=*/4);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const auto* root = module->entry_computation()->root_instruction();

  const Shape expected_shape = ShapeUtil::MakeShape(F32, {4, 64});
  EXPECT_EQ(root->opcode(), HloOpcode::kAllGather);
  EXPECT_TRUE(ShapeUtil::Equal(root->shape(), expected_shape));
}
409 
// Reduce-scattering F32[4,16] along dim 1 into 2 shards with a scalar add
// reduction yields F32[4,8].
TEST_F(XlaBuilderTest, ReduceScatter) {
  XlaBuilder b(TestName());

  // Reduction computation: scalar F32 addition.
  XlaComputation add_computation;
  {
    auto add_builder = b.CreateSubBuilder("add");
    const Shape scalar_f32 = ShapeUtil::MakeScalarShape(F32);
    XlaOp lhs = Parameter(add_builder.get(), 0, scalar_f32, "x");
    XlaOp rhs = Parameter(add_builder.get(), 1, scalar_f32, "y");
    Add(lhs, rhs);
    TF_ASSERT_OK_AND_ASSIGN(add_computation, add_builder->Build());
  }

  XlaOp input = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
  ReplicaGroup replicas;
  for (int id = 0; id < 2; ++id) {
    replicas.add_replica_ids(id);
  }
  ReduceScatter(input, add_computation, /*scatter_dimension=*/1,
                /*shard_count=*/2, /*replica_groups=*/{replicas});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const auto* root = module->entry_computation()->root_instruction();

  EXPECT_EQ(root->opcode(), HloOpcode::kReduceScatter);
  EXPECT_TRUE(
      ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 8})));
}
435 
// All-to-all on F32[4,16], splitting dim 1 into 2 pieces and concatenating
// along dim 0, yields F32[8,8].
TEST_F(XlaBuilderTest, AllToAll) {
  XlaBuilder b(TestName());
  auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
  AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0,
           /*split_count=*/2);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  auto root = module->entry_computation()->root_instruction();

  // When the split and concat dimensions differ, the all-to-all is not the
  // root: the root is a reshape, and the all-to-all sits three operand hops
  // below it. The intermediate ops are not pinned here. Stop early if the root
  // is not a reshape, so the operand walk below runs only on the expected
  // graph shape.
  ASSERT_EQ(root->opcode(), HloOpcode::kReshape);
  EXPECT_EQ(root->operand(0)->operand(0)->operand(0)->opcode(),
            HloOpcode::kAllToAll);
  EXPECT_TRUE(
      ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {8, 8})));
}
451 
452 // Test the special case where split_dimension is the same as concat_dimension.
TEST_F(XlaBuilderTest,AllToAllSpecial)453 TEST_F(XlaBuilderTest, AllToAllSpecial) {
454   XlaBuilder b(TestName());
455   auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16, 8}), "x");
456   AllToAll(x, /*split_dimension=*/0, /*concat_dimension=*/0,
457            /*split_count=*/2);
458   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
459   auto root = module->entry_computation()->root_instruction();
460 
461   // AllToAll is converted into a single all-to-all HloInstruction.
462   EXPECT_EQ(root->opcode(), HloOpcode::kAllToAll);
463   EXPECT_TRUE(
464       ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 16, 8})));
465 }
466 
// The tuple form of all-to-all becomes a single all-to-all instruction whose
// shape is a tuple of the operands with the requested layout, and whose replica
// groups match the input.
TEST_F(XlaBuilderTest, AllToAllTuple) {
  XlaBuilder b(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2, 4});
  XlaOp first = Parameter(&b, 0, operand_shape, "p0");
  XlaOp second = Parameter(&b, 1, operand_shape, "p1");

  ReplicaGroup group;
  for (int id : {0, 1}) {
    group.add_replica_ids(id);
  }
  AllToAllTuple({first, second}, {group}, LayoutUtil::MakeAscendingLayout(2));

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const auto* root = module->entry_computation()->root_instruction();

  EXPECT_EQ(root->opcode(), HloOpcode::kAllToAll);
  const Shape element_shape =
      ShapeUtil::MakeShapeWithLayout(F32, /* dimensions= */ {2, 4},
                                     /* minor_to_major= */ {0, 1});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({element_shape, element_shape});
  EXPECT_THAT(root, op::ShapeWithLayout(tuple_shape));
  EXPECT_THAT(root, op::ReplicaGroups({{0, 1}}));
}
488 
// A collective permute over a chain of source/target pairs yields a single
// collective-permute instruction.
TEST_F(XlaBuilderTest, CollectivePermute) {
  XlaBuilder b(TestName());
  XlaOp input = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
  CollectivePermute(input, {{0, 1}, {1, 2}, {2, 3}});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kCollectivePermute);
}
497 
// Querying the size of a dynamic dimension (dim 1 here) emits a
// get-dimension-size instruction.
TEST_F(XlaBuilderTest, GetDimensionSize) {
  XlaBuilder b(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}, {false, true});
  XlaOp input = Parameter(&b, 0, shape, "x");
  GetDimensionSize(input, 1);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kGetDimensionSize);
}
507 
// Querying the size of a static dimension (dim 0 here) folds to a constant
// instead of emitting a get-dimension-size instruction.
TEST_F(XlaBuilderTest, GetDimensionSizeConstant) {
  XlaBuilder b(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {5, 7}, {false, true});
  XlaOp input = Parameter(&b, 0, shape, "x");
  GetDimensionSize(input, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->opcode(), HloOpcode::kConstant);
}
518 
// An error injected through ReportError must surface from Build().
TEST_F(XlaBuilderTest, ReportError) {
  XlaBuilder b(TestName());
  XlaOp input = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5, 7}), "x");
  XlaOp failed = b.ReportError(InvalidArgument("a test error"));
  Add(failed, input);
  auto build_result = b.Build();
  ASSERT_FALSE(build_result.ok());
  EXPECT_THAT(build_result.status().error_message(),
              HasSubstr("a test error"));
}
527 
// ReportErrorOrReturn on an OK StatusOr passes the wrapped op through
// unchanged.
TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesNonErrors) {
  XlaBuilder b(TestName());
  // Named `maybe_constant` (not `op`) so it does not shadow the `op` matcher
  // namespace alias used below.
  StatusOr<XlaOp> maybe_constant(ConstantR0<float>(&b, 1.0));
  Add(b.ReportErrorOrReturn(maybe_constant), ConstantR0<float>(&b, 2.0));
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Add(op::Constant(), op::Constant()));
}
536 
// ReportErrorOrReturn on an error StatusOr records the error, so Build()
// fails with it.
TEST_F(XlaBuilderTest, ReportErrorOrReturnHandlesErrors) {
  XlaBuilder b(TestName());
  // Named `failed_op` (not `op`) so it does not shadow the `op` matcher
  // namespace alias.
  StatusOr<XlaOp> failed_op(InvalidArgument("a test error"));
  Add(b.ReportErrorOrReturn(failed_op), ConstantR0<float>(&b, 2.0));
  auto statusor = b.Build();
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(), HasSubstr("a test error"));
}
545 
// Passing an explicit root makes that op the entry root, even though a later
// add was recorded.
TEST_F(XlaBuilderTest, BuildWithSpecificRoot) {
  XlaBuilder b(TestName());
  XlaOp chosen_root = ConstantR0<float>(&b, 1.0);
  XlaOp other = ConstantR0<float>(&b, 2.0);
  Add(chosen_root, other);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          BuildHloModule(&b, /*root=*/chosen_root));
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Constant());
}
554 
// Specifying a particular root in Build should still include all entry
// parameters. The expected five instructions are three parameters, the
// subtract and the add.
TEST_F(XlaBuilderTest, BuildWithSpecificRootAndMultipleParameters) {
  XlaBuilder b(TestName());
  const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
  XlaOp x = Parameter(&b, 0, shape, "x");
  XlaOp y = Parameter(&b, 1, shape, "y");
  XlaOp z = Parameter(&b, 2, shape, "z");
  XlaOp difference = Sub(y, z);
  Add(x, difference);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/x));

  const auto* entry = module->entry_computation();
  EXPECT_THAT(entry->root_instruction(), op::Parameter());
  EXPECT_EQ(entry->num_parameters(), 3);
  EXPECT_EQ(entry->instruction_count(), 5);
}
570 
// Building with a root that belongs to a different builder must fail.
TEST_F(XlaBuilderTest, BuildWithSpecificRootWithWrongBuilder) {
  const Shape shape = ShapeUtil::MakeShape(F32, {42, 123});
  XlaBuilder b(TestName());
  XlaBuilder other_b(TestName());

  Parameter(&b, 0, shape, "param");
  XlaOp foreign_root = Parameter(&other_b, 0, shape, "other_param");

  Status build_status = b.Build(foreign_root).status();
  ASSERT_IS_NOT_OK(build_status);
  EXPECT_THAT(
      build_status.error_message(),
      ::testing::HasSubstr("root operation is not in this computation"));
}
585 
// Building the same graph twice must serialize to byte-identical protos, i.e.
// proto generation is deterministic.
TEST_F(XlaBuilderTest, ProtoMatches) {
  constexpr int kNumComputations = 2;
  std::vector<XlaComputation> computations;
  computations.reserve(kNumComputations);
  for (int i = 0; i < kNumComputations; ++i) {
    XlaBuilder b_call("the_only_to_apply");
    auto p0 = Parameter(&b_call, 0, ShapeUtil::MakeShape(F32, {}), "p0");
    auto p1 = Parameter(&b_call, 1, ShapeUtil::MakeShape(F32, {}), "p1");
    Add(p0, Add(p1, p0));
    TF_ASSERT_OK_AND_ASSIGN(auto call, b_call.Build());
    XlaBuilder b(TestName());
    auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
    auto y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "y");
    auto one = ConstantR0<float>(&b, 1);
    auto two = ConstantR0<float>(&b, 2);
    Add(Call(&b, call, {x, y}), Call(&b, call, {one, two}));
    // Report a build failure as a test failure instead of aborting the whole
    // test binary via ValueOrDie().
    TF_ASSERT_OK_AND_ASSIGN(auto computation, b.Build());
    computations.push_back(std::move(computation));
  }
  auto c0_string = computations[0].proto().SerializeAsString();
  auto c1_string = computations[1].proto().SerializeAsString();
  EXPECT_EQ(c0_string, c1_string);
}
608 
// Binding scalar parameter 1 as the dynamic size of dim 0 of tuple element
// {1} in parameter 0 marks that dimension dynamic in the built module.
TEST_F(XlaBuilderTest, DynamicParameter) {
  XlaBuilder b(TestName());
  const Shape static_elem = ShapeUtil::MakeShape(F32, {5});
  const Shape dynamic_elem = ShapeUtil::MakeShape(F32, {6}, {true});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({static_elem, dynamic_elem});
  XlaOp tuple_param = Parameter(&b, 0, tuple_shape, "p0");
  Parameter(&b, 1, ShapeUtil::MakeShape(U32, {}), "p1");

  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/1,
                                   /*dynamic_size_param_index=*/{},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{1},
                                   /*target_dim_num=*/0));
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          BuildHloModule(&b, /*root=*/tuple_param));

  const HloInstruction* param0 =
      module->entry_computation()->parameter_instruction(0);
  const Shape& element_shape = param0->shape().tuple_shapes(1);
  EXPECT_TRUE(element_shape.is_dynamic_dimension(0));
}
627 
// SetDimensionSize marks the targeted dimension of its result as dynamic.
TEST_F(XlaBuilderTest, SetDimensionSize) {
  XlaBuilder b(TestName());
  XlaOp operand = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p0");
  XlaOp size = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1");
  XlaOp resized = SetDimensionSize(operand, size, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b, /*root=*/resized));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_TRUE(root->shape().is_dynamic_dimension(0));
}
639 
// RemoveDynamicDimension clears the dynamic marker that SetDimensionSize put
// on dimension 0.
TEST_F(XlaBuilderTest, RemoveDynamicDimension) {
  XlaBuilder b(TestName());
  XlaOp operand = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p0");
  XlaOp size = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1");
  XlaOp dynamic = SetDimensionSize(operand, size, 0);
  XlaOp made_static = RemoveDynamicDimension(dynamic, 0);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          BuildHloModule(&b, /*root=*/made_static));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_FALSE(root->shape().is_dynamic_dimension(0));
}
653 
// Makes both dimensions of an F32[10,10] value dynamic, then removes the
// dynamism from each in turn. Neither dimension of the result is dynamic.
TEST_F(XlaBuilderTest, RemoveDynamicDimensionMultiDims) {
  XlaBuilder b(TestName());
  XlaOp operand = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10, 10}), "p0");
  XlaOp size = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1");

  XlaOp dynamic_dim0 = SetDimensionSize(operand, size, 0);
  XlaOp dynamic_both = SetDimensionSize(dynamic_dim0, size, 1);
  XlaOp static_dim0 = RemoveDynamicDimension(dynamic_both, 0);
  XlaOp static_both = RemoveDynamicDimension(static_dim0, 1);

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          BuildHloModule(&b, /*root=*/static_both));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  for (int dim = 0; dim < 2; ++dim) {
    EXPECT_FALSE(result_shape.is_dynamic_dimension(dim));
  }
}
670 
TEST_F(XlaBuilderTest, DynamicUnary) {
  // Tuple element 0 is f32[<=5]; its runtime size is stored in element 1.
  // A unary op applied to it must keep dimension 0 dynamic.
  XlaBuilder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {5}, {true});
  const Shape size_shape = ShapeUtil::MakeShape(U32, {});
  XlaOp param = Parameter(
      &builder, 0, ShapeUtil::MakeTupleShape({data_shape, size_shape}), "p0");
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  Neg(GetTupleElement(param, 0));
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_TRUE(root->shape().is_dynamic_dimension(0));
}
688 
TEST_F(XlaBuilderTest, DynamicBinary) {
  // Adds two f32[<=5] operands that share one dynamic size (tuple element 2).
  // The sum must stay dynamic in dimension 0.
  XlaBuilder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {5}, {true});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {operand_shape, operand_shape, ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, tuple_shape, "p0");
  // Bind dimension 0 of tuple elements 0 and 1 to the same size scalar.
  for (int64_t target_index : {0, 1}) {
    ASSERT_IS_OK(builder.SetDynamicBinding(
        /*dynamic_size_param_num=*/0,
        /*dynamic_size_param_index=*/{2},
        /*target_param_num=*/0,
        /*target_param_index=*/{target_index},
        /*target_dim_num=*/0));
  }
  XlaOp lhs = GetTupleElement(param, 0);
  XlaOp rhs = GetTupleElement(param, 1);
  Add(lhs, rhs);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_TRUE(root->shape().is_dynamic_dimension(0));
}
713 
TEST_F(XlaBuilderTest, DynamicBinaryHasBroadcast) {
  // f32[<=5, 4] + f32[<=5], with the rank-1 operand broadcast along output
  // dimension 0. Only dimension 0 of the result should be dynamic.
  XlaBuilder builder(TestName());
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}),
       ShapeUtil::MakeShape(F32, {5}, {true}), ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, tuple_shape, "p0");
  // Both operands take their dimension-0 size from tuple element 2.
  for (int64_t target_index : {0, 1}) {
    ASSERT_IS_OK(builder.SetDynamicBinding(
        /*dynamic_size_param_num=*/0,
        /*dynamic_size_param_index=*/{2},
        /*target_param_num=*/0,
        /*target_param_index=*/{target_index},
        /*target_dim_num=*/0));
  }
  XlaOp matrix = GetTupleElement(param, 0);
  XlaOp vector = GetTupleElement(param, 1);
  Add(matrix, vector, /*broadcast_dimensions=*/{0});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
      << result_shape;
}
739 
TEST_F(XlaBuilderTest, DynamicBroadcast) {
  // Broadcasting f32[<=5, 4] into f32[3, 5, 4] (operand dims -> output dims
  // {1, 2}) moves the dynamic dimension to output dimension 1.
  XlaBuilder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {5, 4}, {true, false});
  const Shape size_shape = ShapeUtil::MakeShape(U32, {});
  XlaOp param = Parameter(
      &builder, 0, ShapeUtil::MakeTupleShape({data_shape, size_shape}), "p0");
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  XlaOp operand = GetTupleElement(param, 0);
  BroadcastInDim(operand, /*out_dim_size=*/{3, 5, 4},
                 /*broadcast_dimensions=*/{1, 2});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(
      ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false}))
      << result_shape;
}
761 
TEST_F(XlaBuilderTest, DynamicBinaryHasDegenerateBroadcast) {
  // Adds f32[<=10] to f32[1, 15], broadcasting the rank-1 operand along output
  // dimension 0. The second operand's size-1 dimension is degenerate, so the
  // result is expected to be f32[<=10, 15].
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {10}, {true}),
       ShapeUtil::MakeShape(F32, {1, 15}), ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  // The runtime size of element {0}'s dimension 0 is the u32 scalar at tuple
  // index {2}. Index {1} is the f32[1, 15] operand and cannot hold a size.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  auto gte0 = GetTupleElement(p0, 0);
  auto gte1 = GetTupleElement(p0, 1);
  Add(gte0, gte1, /*broadcast_dimensions=*/{0});  // f32[<=10, 15]
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
      << result_shape;
}
782 
TEST_F(XlaBuilderTest, DynamicSelectOnlyPredDynamic) {
  // Only the predicate pred[<=10] is dynamic; both branches are static f32[10].
  // The select result must still be dynamic in dimension 0.
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(PRED, {10}, {true}),
       ShapeUtil::MakeShape(F32, {10}), ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  // The predicate's runtime size is the u32 scalar at tuple index {2}.
  // Index {1} is the f32[10] branch value, not a size.
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{2},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  auto gte0 = GetTupleElement(p0, 0);
  auto gte1 = GetTupleElement(p0, 1);

  Select(gte0, gte1, gte1);

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true}))
      << result_shape;
}
805 
TEST_F(XlaBuilderTest, SelectIntoConditional) {
  // Selecting between two tuples with a scalar predicate is expected to be
  // emitted as a conditional whose branches each return their own parameter.
  XlaBuilder builder(TestName());
  const Shape pred_shape = ShapeUtil::MakeShape(PRED, {});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})});
  XlaOp selector = Parameter(&builder, 0, pred_shape, "p0");
  XlaOp on_true = Parameter(&builder, 1, tuple_shape, "p1");
  XlaOp on_false = Parameter(&builder, 2, tuple_shape, "p2");

  Select(selector, on_true, on_false);

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          BuildHloModule(&builder));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Conditional(op::Parameter(0), op::Parameter(1),
                                    op::Parameter(2)));
  for (int branch = 0; branch < 2; ++branch) {
    EXPECT_THAT(root->branch_computation(branch)->root_instruction(),
                op::Parameter(0));
  }
}
833 
TEST_F(XlaBuilderTest, DynamicPad) {
  // A pad with all-zero low/high/interior padding on f32[<=5, 4] must preserve
  // the dynamic dimension.
  XlaBuilder builder(TestName());
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5, 4}, {true, false}),
       ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, tuple_shape, "p0");
  XlaOp padding_value = ConstantR0<float>(&builder, -1);
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  XlaOp operand = GetTupleElement(param, 0);

  PaddingConfig no_padding;
  for (int dim = 0; dim < 2; ++dim) {
    auto* dim_config = no_padding.add_dimensions();
    dim_config->set_edge_padding_low(0);
    dim_config->set_edge_padding_high(0);
    dim_config->set_interior_padding(0);
  }
  Pad(operand, padding_value, no_padding);

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
      << result_shape;
}
861 
TEST_F(XlaBuilderTest, DynamicConvolution) {
  // Input f32[<=1, 2, 2, 128] (dynamic batch) convolved with filter
  // f32[2, 2, <=128, 8] (dynamic input feature). Only the output batch
  // dimension is expected to be dynamic.
  XlaBuilder builder(TestName());
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {1, 2, 2, 128}, {true, false, false, false}),
       ShapeUtil::MakeShape(F32, {2, 2, 128, 8}, {false, false, true, false}),
       ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, tuple_shape, "p0");
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{3},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{1},
                                         /*target_dim_num=*/2));
  XlaOp input = GetTupleElement(param, 0);
  XlaOp filter = GetTupleElement(param, 1);

  // NHWC input/output with an HWIO filter.
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  for (int spatial : {1, 2}) {
    dnums.add_input_spatial_dimensions(spatial);
    dnums.add_output_spatial_dimensions(spatial);
  }
  dnums.set_input_feature_dimension(3);
  dnums.set_output_feature_dimension(3);
  for (int spatial : {0, 1}) {
    dnums.add_kernel_spatial_dimensions(spatial);
  }
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(3);

  ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
                            /*feature_group_count=*/1);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(),
                              {true, false, false, false}))
      << result_shape;
}
903 
TEST_F(XlaBuilderTest, DynamicDot) {
  // Batched matmul of f32[<=2, <=3, 4] with f32[<=2, 4, 5]: batch dim 0,
  // contracting lhs dim 2 with rhs dim 1. Batch and lhs row dims stay dynamic.
  XlaBuilder builder(TestName());
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 3, 4}, {true, true, false}),
       ShapeUtil::MakeShape(F32, {2, 4, 5}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, tuple_shape, "p0");
  // The shared batch size (element 2) drives dim 0 of both operands.
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{1},
                                         /*target_dim_num=*/0));
  // Element 3 drives the lhs row dimension.
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{3},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/1));

  XlaOp lhs = GetTupleElement(param, 0);
  XlaOp rhs = GetTupleElement(param, 1);
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(0);
  DotGeneral(lhs, rhs, dnums);

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(
      ContainersEqual(result_shape.dynamic_dimensions(), {true, true, false}))
      << result_shape;
}
942 
TEST_F(XlaBuilderTest, DynamicReduce) {
  // Sum-reducing dimension 0 of f32[5, <=4, 3] leaves f32[<=4, 3].
  XlaBuilder builder(TestName());
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {5, 4, 3}, {false, true, false}),
       ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, tuple_shape, "p0");
  XlaOp init_value = ConstantR0<float>(&builder, 0);
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/1));
  XlaOp operand = GetTupleElement(param, 0);

  // Scalar f32 addition used as the reducer.
  XlaBuilder adder_builder(TestName());
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  XlaOp x = Parameter(&adder_builder, 0, scalar, "x");
  XlaOp y = Parameter(&adder_builder, 1, scalar, "y");
  Add(x, y);
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation adder, adder_builder.Build());

  Reduce(operand, init_value, adder, /*dimensions_to_reduce=*/{0});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true, false}))
      << result_shape;
}
967 
TEST_F(XlaBuilderTest, DynamicReduceWindow) {
  // A valid-padded window of size {1, 2, 4} over f32[<=2, 4, 8]: the window
  // spans one element of the dynamic dimension, so it stays dynamic.
  XlaBuilder builder(TestName());
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, tuple_shape, "p0");
  XlaOp init_value = ConstantR0<float>(&builder, 0.f);
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  XlaOp operand = GetTupleElement(param, 0);

  // Scalar f32 addition used as the window reducer.
  XlaBuilder adder_builder(TestName());
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  XlaOp x = Parameter(&adder_builder, 0, scalar, "x");
  XlaOp y = Parameter(&adder_builder, 1, scalar, "y");
  Add(x, y);
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation adder, adder_builder.Build());

  ReduceWindow(operand, init_value, adder, /*window_dimensions=*/{1, 2, 4},
               /*window_strides=*/{1, 1, 1}, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  VLOG(2) << root->ToString() << "\n";
  const Shape& result_shape = root->shape();
  EXPECT_TRUE(
      ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false}))
      << result_shape;
}
996 
TEST_F(XlaBuilderTest, VariadicDynamicReduceWindow) {
  // Two-operand reduce-window over f32[<=2, 4, 8] inputs. Both tuple outputs
  // are expected to keep dimension 0 dynamic.
  XlaBuilder builder(TestName());
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {})});
  XlaOp first_param = Parameter(&builder, 0, tuple_shape, "p0");
  XlaOp second_param = Parameter(&builder, 1, tuple_shape, "p1");
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  std::vector<XlaOp> inputs = {GetTupleElement(first_param, 0),
                               GetTupleElement(second_param, 0)};

  // Reducer: (x0, x1), (y0, y1) -> (x0 + y0, x1 + y1).
  XlaBuilder reducer_builder(TestName());
  const Shape scalar = ShapeUtil::MakeShape(F32, {});
  XlaOp x0 = Parameter(&reducer_builder, 0, scalar, "x0");
  XlaOp x1 = Parameter(&reducer_builder, 1, scalar, "x1");
  XlaOp y0 = Parameter(&reducer_builder, 2, scalar, "y0");
  XlaOp y1 = Parameter(&reducer_builder, 3, scalar, "y1");
  std::vector<XlaOp> sums = {Add(x0, y0), Add(x1, y1)};
  Tuple(&reducer_builder, absl::MakeSpan(sums));
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation reducer, reducer_builder.Build());

  XlaOp init_value = ConstantR0<float>(&builder, 0.f);
  ReduceWindow(inputs, {init_value, init_value}, reducer,
               /*window_dimensions=*/{1, 2, 4},
               /*window_strides=*/{1, 1, 1}, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  VLOG(2) << root->ToString() << "\n";
  const Shape& result_shape = root->shape();
  for (int i = 0; i < 2; ++i) {
    const Shape& element = result_shape.tuple_shapes(i);
    EXPECT_TRUE(ContainersEqual(element.dynamic_dimensions(),
                                {true, false, false}))
        << element;
  }
}
1036 
TEST_F(XlaBuilderTest, DynamicSelectAndScatter) {
  // Select-and-scatter over f32[<=2, 4, 8] with source f32[<=2, 2, 2]. The
  // result has the operand's shape, so dimension 0 stays dynamic.
  XlaBuilder builder(TestName());
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
       ShapeUtil::MakeShape(F32, {2, 2, 2}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, tuple_shape, "p0");
  XlaOp init_value = ConstantR0<float>(&builder, 0.f);
  const Shape scalar = ShapeUtil::MakeShape(F32, {});

  // Scatter computation: scalar f32 addition.
  XlaBuilder adder_builder(TestName());
  XlaOp add_x = Parameter(&adder_builder, 0, scalar, "x");
  XlaOp add_y = Parameter(&adder_builder, 1, scalar, "y");
  Add(add_x, add_y);
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation adder, adder_builder.Build());

  // Select computation: scalar f32 greater-or-equal.
  XlaBuilder ge_builder(TestName());
  XlaOp ge_x = Parameter(&ge_builder, 0, scalar, "x");
  XlaOp ge_y = Parameter(&ge_builder, 1, scalar, "y");
  Ge(ge_x, ge_y);
  TF_ASSERT_OK_AND_ASSIGN(XlaComputation greater_equal, ge_builder.Build());

  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{1},
                                         /*target_dim_num=*/0));
  XlaOp operand = GetTupleElement(param, 0);
  XlaOp source = GetTupleElement(param, 1);
  SelectAndScatter(operand, greater_equal, {1, 2, 4}, {1, 2, 4},
                   Padding::kValid, source, init_value, adder);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(
      ContainersEqual(result_shape.dynamic_dimensions(), {true, false, false}))
      << result_shape;
}
1075 
TEST_F(XlaBuilderTest, DynamicReshape) {
  // Reshapes f32[2, 3, <=4, <=5, 6] into [6, 4, 5, 2, 3]. The two dynamic
  // input dimensions should map onto output dimensions 1 and 2.
  XlaBuilder builder(TestName());
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6},
                            {false, false, true, true, false}),
       ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, tuple_shape, "p0");
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/2));
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/3));
  XlaOp operand = GetTupleElement(param, 0);  // f32[2, 3, <=4, <=5, 6]
  Reshape(operand, /*new_sizes=*/{6, 4, 5, 2, 3});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  for (int dim : {1, 2}) {
    EXPECT_TRUE(result_shape.is_dynamic_dimension(dim));
  }
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(),
                              {false, true, true, false, false}))
      << result_shape;
}
1104 
TEST_F(XlaBuilderTest, DynamicSelect) {
  // Both select branches are f32[4, <=5, 6]; a scalar predicate picks one.
  // The result must be dynamic in dimension 1 only.
  XlaBuilder builder(TestName());
  const Shape branch_shape =
      ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false});
  const Shape size_shape = ShapeUtil::MakeShape(U32, {});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {branch_shape, branch_shape, size_shape, size_shape});
  XlaOp param = Parameter(&builder, 0, tuple_shape, "p0");
  XlaOp pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "pred");
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/1));
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{3},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{1},
                                         /*target_dim_num=*/1));
  XlaOp on_true = GetTupleElement(param, 0);
  XlaOp on_false = GetTupleElement(param, 1);
  Select(pred, on_true, on_false);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(result_shape.is_dynamic_dimension(1));
  EXPECT_FALSE(result_shape.is_dynamic_dimension(2));
  EXPECT_TRUE(
      ContainersEqual(result_shape.dynamic_dimensions(), {false, true, false}))
      << result_shape;
}
1135 
TEST_F(XlaBuilderTest, DynamicSelectNotCompatible) {
  // The select branches are dynamic in different dimensions
  // (f32[4,<=5,6] vs f32[4,5,<=6]). Despite the test's name, building the
  // module is expected to succeed.
  XlaBuilder builder(TestName());
  const Shape tuple_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, true, false}),
       ShapeUtil::MakeShape(F32, {4, 5, 6}, {false, false, true}),
       ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})});
  XlaOp param = Parameter(&builder, 0, tuple_shape, "p0");
  XlaOp pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "pred");
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{2},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/1));
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{3},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{1},
                                         /*target_dim_num=*/2));
  XlaOp on_true = GetTupleElement(param, 0);   // f32[4,<=5,6]
  XlaOp on_false = GetTupleElement(param, 1);  // f32[4,5,<=6]
  Select(pred, on_true, on_false);
  ASSERT_IS_OK(BuildHloModule(&builder).status());
}
1160 
TEST_F(XlaBuilderTest, DynamicTranspose) {
  // Transposing f32[<=3, 5] with permutation {1, 0} moves the dynamic
  // dimension from position 0 to position 1.
  XlaBuilder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {3, 5}, {true, false});
  const Shape size_shape = ShapeUtil::MakeShape(U32, {});
  XlaOp param = Parameter(
      &builder, 0, ShapeUtil::MakeTupleShape({data_shape, size_shape}), "p0");
  ASSERT_IS_OK(builder.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                         /*dynamic_size_param_index=*/{1},
                                         /*target_param_num=*/0,
                                         /*target_param_index=*/{0},
                                         /*target_dim_num=*/0));
  XlaOp operand = GetTupleElement(param, 0);
  Transpose(operand, /*permutation=*/{1, 0});
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {false, true}))
      << result_shape;
}
1180 
TEST_F(XlaBuilderTest, DotWithPreferredElementType) {
  // u8[2, 3] x u16[3, 2] with preferred element type U32 must yield u32[2, 2].
  XlaBuilder builder(TestName());
  XlaOp lhs = Parameter(&builder, 0, ShapeUtil::MakeShape(U8, {2, 3}), "p0");
  XlaOp rhs = Parameter(&builder, 1, ShapeUtil::MakeShape(U16, {3, 2}), "p1");

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  DotGeneral(lhs, rhs, dnums, /*precision_config=*/nullptr,
             /*preferred_element_type=*/U32);

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape expected_shape = ShapeUtil::MakeShape(U32, {2, 2});
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  ASSERT_TRUE(ShapeUtil::Equal(expected_shape, result_shape));
}
1199 
TEST_F(XlaBuilderTest, ConvolutionWithPreferredElementType) {
  // s16 input convolved with an s8 filter, requesting S32 output. A valid
  // 2x2 window over a 2x2 spatial input gives s32[1, 1, 1, 8].
  XlaBuilder builder(TestName());
  XlaOp input =
      Parameter(&builder, 0, ShapeUtil::MakeShape(S16, {1, 2, 2, 128}), "p0");
  XlaOp filter =
      Parameter(&builder, 1, ShapeUtil::MakeShape(S8, {2, 2, 128, 8}), "p1");

  // NHWC input/output with an HWIO filter.
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  for (int spatial : {1, 2}) {
    dnums.add_input_spatial_dimensions(spatial);
    dnums.add_output_spatial_dimensions(spatial);
  }
  dnums.set_input_feature_dimension(3);
  dnums.set_output_feature_dimension(3);
  for (int spatial : {0, 1}) {
    dnums.add_kernel_spatial_dimensions(spatial);
  }
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(3);

  ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, dnums,
                            /*feature_group_count=*/1, /*batch_group_count=*/1,
                            /*precision_config=*/nullptr,
                            /*preferred_element_type=*/S32);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder));
  const Shape expected_shape = ShapeUtil::MakeShape(S32, {1, 1, 1, 8});
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  ASSERT_TRUE(ShapeUtil::Equal(expected_shape, result_shape));
}
1230 
TEST_F(XlaBuilderTest, AfterAllWithNonTokenOperands) {
  // AfterAll accepts only token operands. Passing an f32 constant alongside a
  // token must make Build() fail with a descriptive message.
  XlaBuilder builder(TestName());
  XlaOp token = CreateToken(&builder);
  XlaOp not_a_token = ConstantR0<float>(&builder, 1.0);
  AfterAll(&builder, {token, not_a_token});
  const Status build_status = builder.Build().status();
  ASSERT_IS_NOT_OK(build_status);
  EXPECT_THAT(build_status.error_message(),
              HasSubstr("All operands to AfterAll must be tokens"));
}
1239 
TEST_F(XlaBuilderTest, CheckInputOutputAlias) {
  // Cross-wires aliases (output {1} <- param 0, output {0} <- param 1) and
  // checks that the built module's alias config records exactly that mapping.
  XlaBuilder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {8, 4});
  XlaOp lhs = Parameter(&builder, 0, operand_shape, "p0");
  XlaOp rhs = Parameter(&builder, 1, operand_shape, "p1");
  XlaOp sum = Add(lhs, rhs);
  XlaOp difference = Sub(lhs, rhs);
  XlaOp root = Tuple(&builder, {sum, difference});

  builder.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0,
                     /*param_index=*/{});
  builder.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
                     /*param_index=*/{});

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&builder, root));

  const HloInputOutputAliasConfig& alias_config =
      module->input_output_alias_config();
  EXPECT_TRUE(alias_config.ParameterHasAlias(0, {}));
  EXPECT_TRUE(alias_config.ParameterHasAlias(1, {}));

  const auto output_for_p0 = alias_config.GetAliasedOutput(0, {});
  ASSERT_TRUE(output_for_p0.has_value());
  EXPECT_EQ(*output_for_p0, ShapeIndex({1}));

  const auto output_for_p1 = alias_config.GetAliasedOutput(1, {});
  ASSERT_TRUE(output_for_p1.has_value());
  EXPECT_EQ(*output_for_p1, ShapeIndex({0}));
}
1265 
ExpectAttributesMatch(const FrontendAttributes & attr,const FrontendAttributes & ref)1266 void ExpectAttributesMatch(const FrontendAttributes& attr,
1267                            const FrontendAttributes& ref) {
1268   EXPECT_EQ(ref.map_size(), attr.map_size());
1269   for (auto reference : ref.map()) {
1270     auto other = attr.map().find(reference.first);
1271     EXPECT_NE(other, attr.map().end());
1272     EXPECT_EQ(other->second, reference.second);
1273   }
1274 }
1275 
ExpectInstructionsAttributesMatch(const HloModule & module,const std::vector<FrontendAttributes> & expected)1276 void ExpectInstructionsAttributesMatch(
1277     const HloModule& module, const std::vector<FrontendAttributes>& expected) {
1278   ASSERT_EQ(module.computation_count(), 1);
1279   auto expected_it = expected.begin();
1280   for (auto inst : module.entry_computation()->instructions()) {
1281     ASSERT_NE(expected_it, expected.end());
1282     ExpectAttributesMatch(inst->frontend_attributes(), *expected_it);
1283     expected_it++;
1284   }
1285   EXPECT_EQ(expected_it, expected.end());
1286 }
1287 
// Frontend attributes set on the builder apply to instructions created while
// they are active, and stop applying after ClearFrontendAttributes().
TEST_F(XlaBuilderTest, SimpleSetFrontendAttributes) {
  XlaBuilder b(TestName());
  std::vector<FrontendAttributes> expected;

  // Created before any attributes are set: none expected.
  ConstantR0(&b, 0);
  expected.emplace_back();

  // Created while { "attr_a": "a" } is active on the builder.
  FrontendAttributes attr_a;
  (*attr_a.mutable_map())["attr_a"] = "a";
  b.SetFrontendAttributes(attr_a);
  ConstantR0(&b, 0);
  expected.push_back(attr_a);

  // Created after clearing: none expected again.
  b.ClearFrontendAttributes();
  ConstantR0(&b, 0);
  expected.emplace_back();

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  ExpectInstructionsAttributesMatch(*module, expected);
}
1307 
// Successive SetFrontendAttributes() calls each replace the builder's active
// attribute set; every constant carries exactly the set active when it was
// created, and ClearFrontendAttributes() drops them for later instructions.
TEST_F(XlaBuilderTest, ComplexSetFrontendAttributes) {
  XlaBuilder b(TestName());
  std::vector<FrontendAttributes> expected;

  // Created before any attributes are set: none expected.
  ConstantR0(&b, 0);
  expected.emplace_back();

  // Makes `attributes` the builder's active set, emits one constant under it,
  // and records it as that constant's expected attributes.
  auto build_constant_under = [&b,
                               &expected](const FrontendAttributes& attributes) {
    b.SetFrontendAttributes(attributes);
    ConstantR0(&b, 0);
    expected.push_back(attributes);
  };

  FrontendAttributes only_a;
  (*only_a.mutable_map())["attr_a"] = "a";
  build_constant_under(only_a);  // { "attr_a": "a" }

  FrontendAttributes only_b;
  (*only_b.mutable_map())["attr_b"] = "b";
  build_constant_under(only_b);  // { "attr_b": "b" }

  FrontendAttributes b_and_c;
  (*b_and_c.mutable_map())["attr_b"] = "b";
  (*b_and_c.mutable_map())["attr_c"] = "c";
  build_constant_under(b_and_c);  // { "attr_b": "b", "attr_c": "c" }

  // Created after clearing: none expected.
  b.ClearFrontendAttributes();
  ConstantR0(&b, 0);
  expected.emplace_back();

  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  ExpectInstructionsAttributesMatch(*module, expected);
}
1346 
TEST_F(XlaBuilderTest,AddFrontendAttribute)1347 TEST_F(XlaBuilderTest, AddFrontendAttribute) {
1348   XlaBuilder b(TestName());
1349 
1350   ConstantR0(&b, 0);
1351   std::vector<FrontendAttributes> expected{FrontendAttributes()};
1352 
1353   // One attribute: { "attr_a": "a" }
1354   {
1355     FrontendAttributes attributes;
1356     (*attributes.mutable_map())["attr_a"] = "a";
1357     b.SetFrontendAttributes(attributes);
1358     ConstantR0(&b, 0);
1359     expected.push_back(attributes);
1360   }
1361 
1362   // Two attributes: {"attra": "a", "attr_c": "c"}
1363   {
1364     auto op = ConstantR0(&b, 0);
1365     EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_c", "c"));
1366 
1367     FrontendAttributes attributes;
1368     (*attributes.mutable_map())["attr_a"] = "a";
1369     (*attributes.mutable_map())["attr_c"] = "c";
1370     expected.push_back(attributes);
1371   }
1372 
1373   // Override value of existing "attr_a"
1374   // One attribute: { "attr_a", "a2"}
1375   {
1376     auto op = ConstantR0(&b, 0);
1377     EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_a", "a2"));
1378     FrontendAttributes attributes;
1379     (*attributes.mutable_map())["attr_a"] = "a2";
1380     expected.push_back(attributes);
1381   }
1382 
1383   // Check "attr_a" is back to its original value
1384   // One attribute: { "attr_a", "a"}
1385   {
1386     auto op = ConstantR0(&b, 0);
1387     (void)op;
1388     FrontendAttributes attributes;
1389     (*attributes.mutable_map())["attr_a"] = "a";
1390     expected.push_back(attributes);
1391   }
1392 
1393   b.ClearFrontendAttributes();
1394   ConstantR0(&b, 0);  // No attribute set
1395   expected.push_back(FrontendAttributes());
1396 
1397   // One attribute: { "attr_d", "d"}
1398   {
1399     auto op = ConstantR0(&b, 0);
1400     EXPECT_IS_OK(b.SetInstructionFrontendAttribute(op, "attr_d", "d"));
1401     FrontendAttributes attributes;
1402     (*attributes.mutable_map())["attr_d"] = "d";
1403     expected.push_back(attributes);
1404   }
1405 
1406   ConstantR0(&b, 0);  // No attribute set
1407   expected.push_back(FrontendAttributes());
1408 
1409   TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
1410   ExpectInstructionsAttributesMatch(*module, expected);
1411 }
1412 
// Le() on two S32 constants must lower to a Compare whose comparison type is
// signed.
TEST_F(XlaBuilderTest, ComparisonType) {
  XlaBuilder b(TestName());
  (void)Le(ConstantR0<int32_t>(&b, 1), ConstantR0<int32_t>(&b, 2));
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));

  const HloInstruction* compare =
      module->entry_computation()->root_instruction();
  ASSERT_THAT(compare, op::Compare(op::Constant(), op::Constant()));

  const auto* compare_instruction = DynCast<HloCompareInstruction>(compare);
  EXPECT_EQ(Comparison::Type::kSigned, compare_instruction->type());
}
1422 
// The HloInstructionProto pointer looked up for an op handle must remain the
// same after many more instructions are added to the builder.
TEST_F(XlaBuilderTest, StableLookUpInstructionByHandle) {
  XlaBuilder b(TestName());
  internal::XlaBuilderFriend builder_friend;
  XlaOp le = Le(ConstantR0<int32_t>(&b, 1), ConstantR0<int32_t>(&b, 2));
  HloInstructionProto* const proto_before = builder_friend.GetInstruction(le);

  // Add 100 more Le instructions (plus their constant operands).
  for (int remaining = 100; remaining > 0; --remaining) {
    (void)Le(ConstantR0<int32_t>(&b, 1), ConstantR0<int32_t>(&b, 2));
  }

  // Looking up the same handle again must yield the same proto address.
  EXPECT_EQ(proto_before, builder_friend.GetInstruction(le));
}
1436 
// Value inference (upper bound) of Abs applied to a complex<float> constant
// must produce a literal whose element type is F32.
TEST_F(XlaBuilderTest, ComplexAbsConstant) {
  XlaBuilder b(TestName());
  XlaOp out =
      Abs(ConstantR0<std::complex<float>>(&b, std::complex<float>{-1, -1}));
  ValueInference value_inference(&b);
  StatusOr<OptionalLiteral> analyzed =
      value_inference.AnalyzeConstant(out, kUpperBound);
  // Both checks are fatal. The dereferences below are invalid when the
  // status is an error or when GetValue() returns an empty optional.
  ASSERT_IS_OK(analyzed.status());
  auto value = analyzed->GetValue();
  ASSERT_TRUE(value.has_value());
  EXPECT_EQ(value->shape().element_type(), PrimitiveType::F32);
}
1448 
1449 }  // namespace
1450 }  // namespace xla
1451