1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <random>
17 #include <utility>
18
19 #include "tensorflow/compiler/xla/client/xla_builder.h"
20 #include "tensorflow/compiler/xla/client/xla_computation.h"
21 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
22 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
23 #include "tensorflow/compiler/xla/tests/test_macros.h"
24
25 namespace xla {
26 namespace {
27
28 class ConditionalOpTest : public ClientLibraryTestBase {
29 protected:
CreateR0ConstantComputation(float value)30 XlaComputation CreateR0ConstantComputation(float value) {
31 XlaBuilder builder("Constant");
32 Parameter(&builder, 0, empty_tuple_, "tuple");
33 ConstantR0<float>(&builder, value);
34 auto build_status = builder.Build();
35 EXPECT_IS_OK(build_status.status());
36 return std::move(build_status).value();
37 }
38
CreateR0IdentityComputation()39 XlaComputation CreateR0IdentityComputation() {
40 XlaBuilder builder("Identity");
41 Parameter(&builder, 0, r0f32_, "x");
42 auto build_status = builder.Build();
43 EXPECT_IS_OK(build_status.status());
44 return std::move(build_status).value();
45 }
46
CreateCeilComputation(const Shape & shape)47 XlaComputation CreateCeilComputation(const Shape& shape) {
48 XlaBuilder builder("Ceil");
49 auto param = Parameter(&builder, 0, shape, "param");
50 Ceil(param);
51 auto build_status = builder.Build();
52 EXPECT_IS_OK(build_status.status());
53 return std::move(build_status).value();
54 }
55
CreateR0CeilComputation()56 XlaComputation CreateR0CeilComputation() {
57 return CreateCeilComputation(r0f32_);
58 }
59
CreateR1CeilComputation()60 XlaComputation CreateR1CeilComputation() {
61 return CreateCeilComputation(r1s2f32_);
62 }
63
CreateFloorComputation(const Shape & shape)64 XlaComputation CreateFloorComputation(const Shape& shape) {
65 XlaBuilder builder("Floor");
66 auto param = Parameter(&builder, 0, shape, "param");
67 Floor(param);
68 auto build_status = builder.Build();
69 EXPECT_IS_OK(build_status.status());
70 return std::move(build_status).value();
71 }
72
CreateR0FloorComputation()73 XlaComputation CreateR0FloorComputation() {
74 return CreateFloorComputation(r0f32_);
75 }
76
CreateR1FloorComputation()77 XlaComputation CreateR1FloorComputation() {
78 return CreateFloorComputation(r1s2f32_);
79 }
80
CreateTupleCeilComputation(const std::string & computation_name,const Shape & tuple_shape)81 XlaComputation CreateTupleCeilComputation(const std::string& computation_name,
82 const Shape& tuple_shape) {
83 XlaBuilder builder(computation_name);
84 auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
85 auto x = GetTupleElement(tuple, 0);
86 auto y = GetTupleElement(tuple, 1);
87 auto x_ceil = Ceil(x);
88 auto y_ceil = Ceil(y);
89 Tuple(&builder, {x_ceil, y_ceil});
90 auto build_status = builder.Build();
91 EXPECT_IS_OK(build_status.status());
92 return std::move(build_status).value();
93 }
94
CreateR0TupleCeilComputation()95 XlaComputation CreateR0TupleCeilComputation() {
96 return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_);
97 }
98
CreateR1TupleCeilComputation()99 XlaComputation CreateR1TupleCeilComputation() {
100 return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_);
101 }
102
CreateTupleFloorComputation(const std::string & computation_name,const Shape & tuple_shape)103 XlaComputation CreateTupleFloorComputation(
104 const std::string& computation_name, const Shape& tuple_shape) {
105 XlaBuilder builder(computation_name);
106 auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
107 auto x = GetTupleElement(tuple, 0);
108 auto y = GetTupleElement(tuple, 1);
109 auto x_floor = Floor(x);
110 auto y_floor = Floor(y);
111 Tuple(&builder, {x_floor, y_floor});
112 auto build_status = builder.Build();
113 EXPECT_IS_OK(build_status.status());
114 return std::move(build_status).value();
115 }
116
CreateR0TupleFloorComputation()117 XlaComputation CreateR0TupleFloorComputation() {
118 return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_);
119 }
120
CreateR1TupleFloorComputation()121 XlaComputation CreateR1TupleFloorComputation() {
122 return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_);
123 }
124
CreateTupleAddComputation(const std::string & computation_name,const Shape & tuple_shape)125 XlaComputation CreateTupleAddComputation(const std::string& computation_name,
126 const Shape& tuple_shape) {
127 XlaBuilder builder(computation_name);
128 auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
129 auto x = GetTupleElement(tuple, 0);
130 auto y = GetTupleElement(tuple, 1);
131 Add(x, y);
132 auto build_status = builder.Build();
133 EXPECT_IS_OK(build_status.status());
134 return std::move(build_status).value();
135 }
136
CreateR0TupleAddComputation()137 XlaComputation CreateR0TupleAddComputation() {
138 return CreateTupleAddComputation("AddR0", tuple_2_r0f32_);
139 }
140
CreateR1TupleAddComputation()141 XlaComputation CreateR1TupleAddComputation() {
142 return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_);
143 }
144
CreateTupleSubComputation(const std::string & computation_name,const Shape & tuple_shape)145 XlaComputation CreateTupleSubComputation(const std::string& computation_name,
146 const Shape& tuple_shape) {
147 XlaBuilder builder(computation_name);
148 auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
149 auto x = GetTupleElement(tuple, 0);
150 auto y = GetTupleElement(tuple, 1);
151 Sub(x, y);
152 auto build_status = builder.Build();
153 EXPECT_IS_OK(build_status.status());
154 return std::move(build_status).value();
155 }
156
CreateR0TupleSubComputation()157 XlaComputation CreateR0TupleSubComputation() {
158 return CreateTupleSubComputation("SubR0", tuple_2_r0f32_);
159 }
160
CreateR1TupleSubComputation()161 XlaComputation CreateR1TupleSubComputation() {
162 return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_);
163 }
164
165 Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
166 Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2});
167 Shape tuple_2_r0f32_ = ShapeUtil::MakeTupleShape(
168 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
169 Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape(
170 {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(F32, {2})});
171 Shape empty_tuple_ = ShapeUtil::MakeTupleShape({});
172 ErrorSpec error_spec_{0.001};
173 };
174
// Test fixture to run indexed conditional (switch/case) tests with varying
// number of branches. The int test parameter is the branch count; each
// XLA_TEST_P body reads it through GetParam(). All computation factories and
// shapes come from ConditionalOpTest.
class CaseOpTest : public ConditionalOpTest,
                   public ::testing::WithParamInterface<int> {};
179
180 // Test true and false computations that do not take any parameters.
XLA_TEST_F(ConditionalOpTest,Parameters0)181 XLA_TEST_F(ConditionalOpTest, Parameters0) {
182 XlaBuilder builder(TestName());
183 XlaOp pred;
184 auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
185 auto operands = Tuple(&builder, {});
186 auto true_computation = CreateR0ConstantComputation(56.0f);
187 auto false_computation = CreateR0ConstantComputation(12.0f);
188 Conditional(pred, operands, true_computation, operands, false_computation);
189
190 ComputeAndCompareR0<float>(&builder, 56.0f, {pred_arg.get()}, error_spec_);
191 }
192
193 // Test branch computations that do not take any parameters.
XLA_TEST_P(CaseOpTest,Parameters0)194 XLA_TEST_P(CaseOpTest, Parameters0) {
195 int num_branches = GetParam();
196 for (int bi = -1; bi <= num_branches; ++bi) {
197 SCOPED_TRACE(bi);
198 XlaBuilder builder(TestName());
199 XlaOp branch_index;
200 auto branch_index_arg = CreateR0Parameter<int32_t>(
201 bi, 0, "branch_index_arg", &builder, &branch_index);
202 auto operand = Tuple(&builder, {});
203
204 std::vector<XlaOp> operands(num_branches, operand);
205 std::vector<XlaComputation> branches;
206 branches.reserve(num_branches);
207 std::vector<const XlaComputation*> branches_p(num_branches);
208 for (int i = 0; i < num_branches; ++i) {
209 branches.emplace_back(
210 CreateR0ConstantComputation(static_cast<float>(i) * 10));
211 branches_p[i] = &branches[i];
212 }
213 Conditional(branch_index, branches_p, operands);
214
215 float expected = 10 * static_cast<float>((bi < 0 || bi >= num_branches)
216 ? num_branches - 1
217 : bi);
218 ComputeAndCompareR0<float>(&builder, expected, {branch_index_arg.get()},
219 error_spec_);
220 }
221 }
222
223 // Test true and false computations that take in 1 parameter.
XLA_TEST_F(ConditionalOpTest,Parameters1)224 XLA_TEST_F(ConditionalOpTest, Parameters1) {
225 XlaBuilder builder(TestName());
226 XlaOp pred;
227 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
228 auto operand1 = ConstantR0<float>(&builder, 56.0f);
229 auto operand2 = ConstantR0<float>(&builder, 12.0f);
230 auto identity = CreateR0IdentityComputation();
231 Conditional(pred, operand1, identity, operand2, identity);
232
233 ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
234 }
235
236 // Test branch computations that take in 1 parameter.
XLA_TEST_P(CaseOpTest,Parameters1)237 XLA_TEST_P(CaseOpTest, Parameters1) {
238 int num_branches = GetParam();
239 for (int bi = -1; bi <= num_branches; ++bi) {
240 SCOPED_TRACE(bi);
241 XlaBuilder builder(TestName());
242 XlaOp branch_index;
243 auto branch_index_arg = CreateR0Parameter<int32_t>(
244 bi, 0, "branch_index_arg", &builder, &branch_index);
245
246 auto make_branch = [&builder, this](int i) {
247 auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
248 Add(ConstantR0<float>(sb.get(), static_cast<float>(i)),
249 Parameter(sb.get(), 0, r0f32_, "p0"));
250 return sb->BuildAndNoteError();
251 };
252 std::vector<XlaComputation> branches;
253 branches.reserve(num_branches);
254 std::vector<const XlaComputation*> branches_p(num_branches);
255 std::vector<XlaOp> operands;
256 operands.reserve(num_branches);
257 std::vector<float> expecteds(num_branches);
258 for (int i = 0; i < num_branches; ++i) {
259 branches.emplace_back(make_branch(i));
260 branches_p[i] = &branches[i];
261 auto fi = static_cast<float>(i);
262 operands.emplace_back(ConstantR0<float>(&builder, 10 * fi + 7));
263 expecteds[i] = 10 * fi + 7 + fi;
264 }
265
266 Conditional(branch_index, branches_p, operands);
267 float expected = (bi < 0 || bi >= num_branches)
268 ? expecteds[num_branches - 1]
269 : expecteds[bi];
270 ComputeAndCompareR0<float>(&builder, expected, {branch_index_arg.get()},
271 error_spec_);
272 }
273 }
274
275 // Test conditional with two different computations in the true and false cases
276 // that take in different arguments.
XLA_TEST_F(ConditionalOpTest,DiffComputationsDiffArgs)277 XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
278 XlaBuilder builder(TestName());
279 XlaOp pred;
280 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
281 auto operand1 = ConstantR0<float>(&builder, 56.4f);
282 auto operand2 = ConstantR0<float>(&builder, 12.6f);
283 Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
284 CreateR0FloorComputation());
285
286 ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
287 }
288
289 // Test conditional with two different computations in the true and false cases
290 // that take in the same arguments.
XLA_TEST_F(ConditionalOpTest,DiffComputationsSameArg)291 XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) {
292 XlaBuilder builder(TestName());
293 XlaOp pred;
294 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
295 auto operand = ConstantR0<float>(&builder, 12.6f);
296 Conditional(pred, operand, CreateR0CeilComputation(), operand,
297 CreateR0FloorComputation());
298
299 ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
300 }
301
302 // Test conditional with the same computation in the true and false cases but
303 // take in different arguments.
XLA_TEST_F(ConditionalOpTest,SameComputationDiffArgs)304 XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) {
305 XlaBuilder builder(TestName());
306 XlaOp pred;
307 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
308 auto operand1 = ConstantR0<float>(&builder, 56.4f);
309 auto operand2 = ConstantR0<float>(&builder, 12.6f);
310 auto floor = CreateR0FloorComputation();
311 Conditional(pred, operand1, floor, operand2, floor);
312
313 ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
314 }
315
316 // Test conditional with the same computation in the true and false cases that
317 // take in the same arguments.
XLA_TEST_F(ConditionalOpTest,SameComputationSameArg)318 XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) {
319 XlaBuilder builder(TestName());
320 XlaOp pred;
321 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
322 auto operand = ConstantR0<float>(&builder, 12.6f);
323 auto floor = CreateR0FloorComputation();
324 Conditional(pred, operand, floor, operand, floor);
325
326 ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
327 }
328
329 // Test conditional with different instances of the same computation in the true
330 // and false cases.
XLA_TEST_F(ConditionalOpTest,SameComputationDiffInstances)331 XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) {
332 XlaBuilder builder(TestName());
333 XlaOp pred;
334 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
335 auto operand1 = ConstantR0<float>(&builder, 56.4f);
336 auto operand2 = ConstantR0<float>(&builder, 12.6f);
337 Conditional(pred, operand1, CreateR0FloorComputation(), operand2,
338 CreateR0FloorComputation());
339
340 ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
341 }
342
343 // Test the case when a call invokes a computation that contains a conditional.
XLA_TEST_F(ConditionalOpTest,ConditionalWithCall)344 XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
345 Shape r0bool = ShapeUtil::MakeShape(PRED, {});
346 XlaBuilder inner_builder(TestName() + ".inner_conditional");
347 auto pred_cond = Parameter(&inner_builder, 0, r0bool, "param0");
348 auto true_operand = Parameter(&inner_builder, 1, r0f32_, "param1");
349 auto false_operand = Parameter(&inner_builder, 2, r0f32_, "param2");
350 Conditional(pred_cond, true_operand, CreateR0CeilComputation(), false_operand,
351 CreateR0FloorComputation());
352 auto inner_builder_result = inner_builder.Build().value();
353
354 XlaBuilder builder(TestName());
355 XlaOp pred;
356 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
357 auto operand1 = ConstantR0<float>(&builder, 56.4f);
358 auto operand2 = ConstantR0<float>(&builder, 12.6f);
359 Call(&builder, inner_builder_result, {pred, operand1, operand2});
360
361 ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
362 }
363
364 // Test true and false computations that take in 2 parameters and predicate is
365 // true.
XLA_TEST_F(ConditionalOpTest,Parameters2TrueBranch)366 XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) {
367 XlaBuilder builder(TestName());
368 XlaOp pred;
369 auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
370 auto operand1 = ConstantR0<float>(&builder, 56.0f);
371 auto operand2 = ConstantR0<float>(&builder, 12.0f);
372 auto operands = Tuple(&builder, {operand1, operand2});
373 Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
374 CreateR0TupleSubComputation());
375
376 ComputeAndCompareR0<float>(&builder, 68.0f, {pred_arg.get()}, error_spec_);
377 }
378
379 // Test true and false computations that take in 2 parameters and predicate is
380 // false.
XLA_TEST_F(ConditionalOpTest,Parameters2FalseBranch)381 XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) {
382 XlaBuilder builder(TestName());
383 XlaOp pred;
384 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
385 auto operand1 = ConstantR0<float>(&builder, 56.0f);
386 auto operand2 = ConstantR0<float>(&builder, 12.0f);
387 auto operands = Tuple(&builder, {operand1, operand2});
388 Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
389 CreateR0TupleSubComputation());
390
391 ComputeAndCompareR0<float>(&builder, 44.0f, {pred_arg.get()}, error_spec_);
392 }
393
394 // Test true and false computations that take in 2 array parameters and
395 // predicate is true.
XLA_TEST_F(ConditionalOpTest,Parameters2ArrayTrueBranch)396 XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
397 XlaBuilder builder(TestName());
398 XlaOp pred;
399 auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
400 auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
401 auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
402 auto operands = Tuple(&builder, {operand1, operand2});
403 Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
404 CreateR1TupleSubComputation());
405
406 ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {pred_arg.get()},
407 error_spec_);
408 }
409
410 // Test branch computations that take in 2 array parameters.
XLA_TEST_P(CaseOpTest,Parameters2Array)411 XLA_TEST_P(CaseOpTest, Parameters2Array) {
412 int num_branches = GetParam();
413 for (int bi = -1; bi <= num_branches; ++bi) {
414 SCOPED_TRACE(bi);
415 XlaBuilder builder(TestName());
416 XlaOp branch_index;
417 auto branch_index_arg =
418 CreateR0Parameter<int32_t>(bi, 0, "pred", &builder, &branch_index);
419 auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
420 auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
421 auto operands = Tuple(&builder, {operand1, operand2});
422 auto make_branch = [&builder, this](int i) {
423 auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
424 auto p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0");
425 Add(Mul(ConstantR0<float>(sb.get(), static_cast<float>(i)),
426 GetTupleElement(p, 0)),
427 GetTupleElement(p, 1));
428 return sb->BuildAndNoteError();
429 };
430 std::vector<XlaComputation> branches;
431 branches.reserve(num_branches);
432 std::vector<const XlaComputation*> branches_p(num_branches);
433 for (int i = 0; i < num_branches; ++i) {
434 branches.emplace_back(make_branch(i));
435 branches_p[i] = &branches[i];
436 }
437 Conditional(branch_index, branches_p,
438 std::vector<XlaOp>(num_branches, operands));
439 auto modified_bi = static_cast<float>(
440 (bi < 0 || bi >= num_branches) ? num_branches - 1 : bi);
441 ComputeAndCompareR1<float>(
442 &builder, {24.0f * modified_bi + 10, 56.0f * modified_bi + 11},
443 {branch_index_arg.get()}, error_spec_);
444 }
445 }
446
// Instantiates the CaseOpTest parameterized tests once for each listed
// parameter value; the test bodies read it as their branch count.
INSTANTIATE_TEST_SUITE_P(CaseOpTest_Instantiation, CaseOpTest,
                         ::testing::Values(1, 2, 3, 4, 5));
449
450 // Test true and false computations that take in 2 array parameters and
451 // predicate is false.
XLA_TEST_F(ConditionalOpTest,Parameters2ArrayFalseBranch)452 XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) {
453 XlaBuilder builder(TestName());
454 XlaOp pred;
455 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
456 auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
457 auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
458 auto operands = Tuple(&builder, {operand1, operand2});
459 Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
460 CreateR1TupleSubComputation());
461
462 ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {pred_arg.get()},
463 error_spec_);
464 }
465
466 // Test true and false computations that return a tuple of scalars.
XLA_TEST_F(ConditionalOpTest,ReturnTupleOfScalars)467 XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
468 XlaBuilder builder(TestName());
469 XlaOp pred;
470 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
471 auto operands = Tuple(&builder, {ConstantR0<float>(&builder, 12.2f),
472 ConstantR0<float>(&builder, 25.6f)});
473 Conditional(pred, operands, CreateR0TupleCeilComputation(), operands,
474 CreateR0TupleFloorComputation());
475
476 ComputeAndCompareTuple(
477 &builder,
478 LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(12.0f),
479 LiteralUtil::CreateR0<float>(25.0f)}),
480 {pred_arg.get()}, error_spec_);
481 }
482
483 // Test true and false computations that return a tuple of arrays.
XLA_TEST_F(ConditionalOpTest,ReturnTupleOfArrays)484 XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
485 XlaBuilder builder(TestName());
486 XlaOp pred;
487 auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
488 auto operands =
489 Tuple(&builder, {ConstantR1<float>(&builder, {12.2f, 15.8f}),
490 ConstantR1<float>(&builder, {25.6f, 29.2f})});
491 Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
492 CreateR1TupleFloorComputation());
493
494 ComputeAndCompareTuple(&builder,
495 LiteralUtil::MakeTupleFromSlices(
496 {LiteralUtil::CreateR1<float>({13.0f, 16.0f}),
497 LiteralUtil::CreateR1<float>({26.0f, 30.0f})}),
498 {pred_arg.get()}, error_spec_);
499 }
500
501 // Test true and false computations that return a tuple of a predicate, a
502 // scalar, and an array.
XLA_TEST_F(ConditionalOpTest,ReturnTupleofPredicateScalarArray)503 XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
504 XlaBuilder true_builder(TestName() + ".true");
505 {
506 Parameter(&true_builder, 0, empty_tuple_, "tuple");
507 auto true_pred = ConstantR0<bool>(&true_builder, true);
508 auto true_scalar = ConstantR0<float>(&true_builder, 12.2f);
509 auto true_array = ConstantR1<float>(&true_builder, {12.8f, 14.6f});
510 Tuple(&true_builder, {true_pred, true_scalar, true_array});
511 }
512 auto true_builder_result = true_builder.Build();
513 EXPECT_IS_OK(true_builder_result.status());
514
515 XlaBuilder false_builder(TestName() + ".false");
516 {
517 Parameter(&false_builder, 0, empty_tuple_, "tuple");
518 auto false_pred = ConstantR0<bool>(&false_builder, false);
519 auto false_scalar = ConstantR0<float>(&false_builder, 25.6f);
520 auto false_array = ConstantR1<float>(&false_builder, {26.4f, 32.6f});
521 Tuple(&false_builder, {false_pred, false_scalar, false_array});
522 }
523 auto false_builder_result = false_builder.Build();
524 EXPECT_IS_OK(false_builder_result.status());
525
526 XlaBuilder builder(TestName());
527 XlaOp pred;
528 auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
529 auto operands = Tuple(&builder, {});
530 Conditional(pred, operands, std::move(true_builder_result).value(), operands,
531 std::move(false_builder_result).value());
532
533 ComputeAndCompareTuple(&builder,
534 LiteralUtil::MakeTupleFromSlices(
535 {LiteralUtil::CreateR0<bool>(true),
536 LiteralUtil::CreateR0<float>(12.2f),
537 LiteralUtil::CreateR1<float>({12.8f, 14.6f})}),
538 {pred_arg.get()}, error_spec_);
539 }
540
541 // Test true and false computations that return a nested tuple.
XLA_TEST_F(ConditionalOpTest,ReturnNestedTuple)542 XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
543 XlaBuilder true_builder(TestName() + ".true");
544 {
545 Parameter(&true_builder, 0, empty_tuple_, "tuple");
546 auto true_constant1 = ConstantR0<float>(&true_builder, 12.2f);
547 auto true_constant2 = ConstantR1<float>(&true_builder, {12.8f, 14.6f});
548 auto true_constant3 = ConstantR1<float>(&true_builder, {25.4f, 29.8f});
549 auto true_constant4 = ConstantR0<float>(&true_builder, 35.6f);
550 Tuple(&true_builder,
551 {Tuple(&true_builder, {true_constant1, true_constant2}),
552 Tuple(&true_builder, {true_constant3, true_constant4})});
553 }
554 auto true_builder_result = true_builder.Build();
555 EXPECT_IS_OK(true_builder_result.status());
556
557 XlaBuilder false_builder(TestName() + ".false");
558 {
559 Parameter(&false_builder, 0, empty_tuple_, "tuple");
560 auto false_constant1 = ConstantR0<float>(&false_builder, 46.6f);
561 auto false_constant2 = ConstantR1<float>(&false_builder, {54.4f, 58.4f});
562 auto false_constant3 = ConstantR1<float>(&false_builder, {62.1f, 67.4f});
563 auto false_constant4 = ConstantR0<float>(&false_builder, 9.3f);
564 Tuple(&false_builder,
565 {Tuple(&false_builder, {false_constant1, false_constant2}),
566 Tuple(&false_builder, {false_constant3, false_constant4})});
567 }
568 auto false_builder_result = false_builder.Build();
569 EXPECT_IS_OK(false_builder_result.status());
570
571 XlaBuilder builder(TestName());
572 XlaOp pred;
573 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
574 auto operands = Tuple(&builder, {});
575 Conditional(pred, operands, std::move(true_builder_result).value(), operands,
576 std::move(false_builder_result).value());
577
578 ComputeAndCompareTuple(
579 &builder,
580 LiteralUtil::MakeTupleFromSlices(
581 {LiteralUtil::MakeTupleFromSlices(
582 {LiteralUtil::CreateR0<float>(46.6f),
583 LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
584 LiteralUtil::MakeTupleFromSlices(
585 {LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
586 LiteralUtil::CreateR0<float>(9.3f)})}),
587 {pred_arg.get()}, error_spec_);
588 }
589
590 // Test conditional that takes in scalar operands in the form of external
591 // params.
XLA_TEST_F(ConditionalOpTest,ScalarOperandsFromExternalParams)592 XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) {
593 Shape r0bool = ShapeUtil::MakeShape(PRED, {});
594 XlaBuilder builder(TestName());
595
596 XlaOp pred, operand1, operand2;
597 auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
598 auto operand1_param =
599 CreateR0Parameter<float>(56.3f, 1, "operand1", &builder, &operand1);
600 auto operand2_param =
601 CreateR0Parameter<float>(12.7f, 2, "operand2", &builder, &operand2);
602 Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
603 CreateR0FloorComputation());
604
605 ComputeAndCompareR0<float>(
606 &builder, 57.0f,
607 {pred_arg.get(), operand1_param.get(), operand2_param.get()},
608 error_spec_);
609 }
610
611 // Test conditional that takes in array operands in the form of external params.
XLA_TEST_F(ConditionalOpTest,ArrayOperandsFromExternalParams)612 XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) {
613 Shape r0bool = ShapeUtil::MakeShape(PRED, {});
614 XlaBuilder builder(TestName());
615
616 XlaOp pred, operand1, operand2;
617 auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
618 auto operand1_param = CreateR1Parameter<float>({24.3f, 56.7f}, 1, "operand1",
619 &builder, &operand1);
620 auto operand2_param = CreateR1Parameter<float>({10.2f, 11.6f}, 2, "operand2",
621 &builder, &operand2);
622 Conditional(pred, operand1, CreateR1CeilComputation(), operand2,
623 CreateR1FloorComputation());
624
625 ComputeAndCompareR1<float>(
626 &builder, {10.0f, 11.0f},
627 {pred_arg.get(), operand1_param.get(), operand2_param.get()},
628 error_spec_);
629 }
630
631 // Test the case where one conditional is nested within another.
XLA_TEST_F(ConditionalOpTest,NestedConditionals)632 XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
633 XlaBuilder inner_builder(TestName() + ".inner_conditional");
634 {
635 Shape r0bool = ShapeUtil::MakeShape(PRED, {});
636 Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
637 auto param0 = Parameter(&inner_builder, 0, tuple_shape, "param0");
638 auto pred_cond = GetTupleElement(param0, 0);
639 auto true_operand = GetTupleElement(param0, 1);
640 auto false_operand = GetTupleElement(param0, 2);
641 Conditional(pred_cond, true_operand, CreateR0CeilComputation(),
642 false_operand, CreateR0FloorComputation());
643 }
644 auto inner_builder_result = inner_builder.Build();
645 EXPECT_IS_OK(inner_builder_result.status());
646
647 XlaBuilder builder(TestName());
648 XlaOp pred1, pred2;
649 auto pred1_arg = CreateR0Parameter<bool>(true, 0, "pred1", &builder, &pred1);
650 auto pred2_arg = CreateR0Parameter<bool>(false, 1, "pred2", &builder, &pred2);
651 auto operand1 = ConstantR0<float>(&builder, 1.1f);
652 auto operand2 = ConstantR0<float>(&builder, 12.2f);
653 auto operand3 = ConstantR0<float>(&builder, 43.3f);
654 auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2});
655 Conditional(pred1, tuple_operand, std::move(inner_builder_result).value(),
656 operand3, CreateR0IdentityComputation());
657
658 ComputeAndCompareR0<float>(&builder, 12.0f,
659 {pred1_arg.get(), pred2_arg.get()}, error_spec_);
660 }
661
// A conditional living inside a computation that is invoked through Call().
// The inner computation unpacks (pred, true_operand, false_operand) from a
// single tuple parameter and selects ceil or floor. The predicate is false, so
// floor(12.2) is expected.
XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
  XlaBuilder inner_builder(TestName() + ".inner_conditional");
  {
    const Shape r0bool = ShapeUtil::MakeShape(PRED, {});
    const Shape packed_shape =
        ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
    const XlaOp packed = Parameter(&inner_builder, 0, packed_shape, "param0");
    const XlaOp inner_pred = GetTupleElement(packed, 0);
    const XlaOp inner_true = GetTupleElement(packed, 1);
    const XlaOp inner_false = GetTupleElement(packed, 2);
    Conditional(inner_pred, inner_true, CreateR0CeilComputation(), inner_false,
                CreateR0FloorComputation());
  }
  auto inner_computation = inner_builder.Build();
  EXPECT_IS_OK(inner_computation.status());

  XlaBuilder builder(TestName());
  XlaOp pred;
  auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
  const XlaOp true_value = ConstantR0<float>(&builder, 1.1f);
  const XlaOp false_value = ConstantR0<float>(&builder, 12.2f);
  const XlaOp packed_operand = Tuple(&builder, {pred, true_value, false_value});
  Call(&builder, std::move(inner_computation).value(), {packed_operand});

  ComputeAndCompareR0<float>(&builder, /*expected=*/12.0f, {pred_arg.get()},
                             error_spec_);
}
687
688 // Test a mismatch in the shape of the true operand and true computation.
XLA_TEST_F(ConditionalOpTest,ShapeMismatch)689 XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
690 XlaBuilder builder(TestName());
691 auto pred = ConstantR0<bool>(&builder, true);
692 auto operand1 = ConstantR0<float>(&builder, 56.0f);
693 auto operand2 = ConstantR0<float>(&builder, 12.0f);
694 auto operands = Tuple(&builder, {operand1, operand2});
695 Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
696 CreateR0TupleSubComputation());
697
698 auto result = builder.Build();
699 EXPECT_FALSE(result.ok());
700 EXPECT_THAT(result.status().error_message(),
701 ::testing::HasSubstr("operand 0 must match the shape of the "
702 "only parameter of branch computation 0"));
703 }
704
// Runs two conditionals back to back inside one computation. The first
// forwards the (x, y) tuple when x < y and swaps it otherwise; the second
// swaps the intermediate result when x >= y and forwards it otherwise. For
// either ordering of the inputs the test expects the original (x, y) back.
XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
  Shape tuple_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_});

  // Builds `builder`, reporting a non-OK build status as a test failure before
  // unwrapping the computation (same pattern as the fixture's factories).
  auto build_checked = [](XlaBuilder* builder) {
    auto built = builder->Build();
    EXPECT_IS_OK(built.status());
    return std::move(built).value();
  };

  // (x, y) -> (y, x)
  XlaComputation swapper;
  {
    XlaBuilder builder(TestName() + ".swapper");
    auto param0 = Parameter(&builder, 0, tuple_shape, "sp0");
    auto x = GetTupleElement(param0, 0);
    auto y = GetTupleElement(param0, 1);
    Tuple(&builder, {y, x});
    swapper = build_checked(&builder);
  }
  // (x, y) -> (x, y)
  XlaComputation forwarder;
  {
    XlaBuilder builder(TestName() + ".forwarder");
    auto param0 = Parameter(&builder, 0, tuple_shape, "fp0");
    auto x = GetTupleElement(param0, 0);
    auto y = GetTupleElement(param0, 1);
    Tuple(&builder, {x, y});
    forwarder = build_checked(&builder);
  }
  XlaComputation main;
  {
    XlaBuilder builder(TestName() + ".main");
    auto param0 = Parameter(&builder, 0, tuple_shape, "mp0");
    auto x = GetTupleElement(param0, 0);
    auto y = GetTupleElement(param0, 1);
    auto lt_pred = Lt(x, y);
    auto res = Conditional(lt_pred, param0, forwarder, param0, swapper);
    auto ge_pred = Ge(x, y);
    Conditional(ge_pred, res, swapper, res, forwarder);
    main = build_checked(&builder);
  }

  auto test_swap = [&](float a, float b) {
    // Identify the input pair in any failure reported from this invocation.
    SCOPED_TRACE(::testing::Message() << "a=" << a << ", b=" << b);
    XlaBuilder builder(TestName());
    XlaOp x, y;
    auto x_arg = CreateR0Parameter<float>(a, 0, "x", &builder, &x);
    auto y_arg = CreateR0Parameter<float>(b, 1, "y", &builder, &y);
    auto tuple_operand = Tuple(&builder, {x, y});
    Call(&builder, main, {tuple_operand});

    ComputeAndCompareTuple(
        &builder,
        LiteralUtil::MakeTupleFromSlices(
            {LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)}),
        {x_arg.get(), y_arg.get()}, error_spec_);
  };

  test_swap(3.11f, 9.4f);
  test_swap(11.24f, 5.55f);
}
756
757 // Test conditional that duplicates tuple elements in the then and else
758 // computations. This is a regression test for b/112550242.
XLA_TEST_F(ConditionalOpTest,DuplicateElementsConditional)759 XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
760 const Shape scalar = ShapeUtil::MakeShape(S32, {});
761 const Shape tuple2 = ShapeUtil::MakeTupleShape({scalar, scalar});
762 XlaComputation then_comp;
763 {
764 XlaBuilder builder(TestName() + ".then");
765 auto p = Parameter(&builder, 0, tuple2, "then.p");
766 auto e0 = GetTupleElement(p, 0);
767 auto e1 = GetTupleElement(p, 1);
768 Tuple(&builder, {e0, e1, e0});
769 then_comp = builder.Build().value();
770 }
771 XlaComputation else_comp;
772 {
773 XlaBuilder builder(TestName() + ".else");
774 auto p = Parameter(&builder, 0, tuple2, "else.p");
775 auto e0 = GetTupleElement(p, 0);
776 auto e1 = GetTupleElement(p, 1);
777 Tuple(&builder, {e0, e1, e1});
778 else_comp = builder.Build().value();
779 }
780
781 {
782 // Pred is true case.
783 std::vector<Literal> args;
784 args.push_back(LiteralUtil::MakeTupleFromSlices(
785 {LiteralUtil::CreateR0<int32_t>(123),
786 LiteralUtil::CreateR0<int32_t>(-42)}));
787 args.push_back(LiteralUtil::CreateR0<bool>(true));
788 XlaBuilder builder(TestName() + ".main");
789 auto p = Parameter(&builder, 0, tuple2, "p0");
790 auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
791 Conditional(p_pred, p, then_comp, p, else_comp);
792 ComputeAndCompare(&builder, args);
793 }
794 {
795 // Pred is false case.
796 std::vector<Literal> args;
797 args.push_back(LiteralUtil::MakeTupleFromSlices(
798 {LiteralUtil::CreateR0<int32_t>(123),
799 LiteralUtil::CreateR0<int32_t>(-42)}));
800 args.push_back(LiteralUtil::CreateR0<bool>(false));
801 XlaBuilder builder(TestName() + ".main");
802 auto p = Parameter(&builder, 0, tuple2, "p0");
803 auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
804 Conditional(p_pred, p, then_comp, p, else_comp);
805 ComputeAndCompare(&builder, args);
806 }
807 }
808
809 } // namespace
810 } // namespace xla
811