xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/conditional_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <random>
17 #include <utility>
18 
19 #include "tensorflow/compiler/xla/client/xla_builder.h"
20 #include "tensorflow/compiler/xla/client/xla_computation.h"
21 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
22 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
23 #include "tensorflow/compiler/xla/tests/test_macros.h"
24 
25 namespace xla {
26 namespace {
27 
28 class ConditionalOpTest : public ClientLibraryTestBase {
29  protected:
CreateR0ConstantComputation(float value)30   XlaComputation CreateR0ConstantComputation(float value) {
31     XlaBuilder builder("Constant");
32     Parameter(&builder, 0, empty_tuple_, "tuple");
33     ConstantR0<float>(&builder, value);
34     auto build_status = builder.Build();
35     EXPECT_IS_OK(build_status.status());
36     return std::move(build_status).value();
37   }
38 
CreateR0IdentityComputation()39   XlaComputation CreateR0IdentityComputation() {
40     XlaBuilder builder("Identity");
41     Parameter(&builder, 0, r0f32_, "x");
42     auto build_status = builder.Build();
43     EXPECT_IS_OK(build_status.status());
44     return std::move(build_status).value();
45   }
46 
CreateCeilComputation(const Shape & shape)47   XlaComputation CreateCeilComputation(const Shape& shape) {
48     XlaBuilder builder("Ceil");
49     auto param = Parameter(&builder, 0, shape, "param");
50     Ceil(param);
51     auto build_status = builder.Build();
52     EXPECT_IS_OK(build_status.status());
53     return std::move(build_status).value();
54   }
55 
CreateR0CeilComputation()56   XlaComputation CreateR0CeilComputation() {
57     return CreateCeilComputation(r0f32_);
58   }
59 
CreateR1CeilComputation()60   XlaComputation CreateR1CeilComputation() {
61     return CreateCeilComputation(r1s2f32_);
62   }
63 
CreateFloorComputation(const Shape & shape)64   XlaComputation CreateFloorComputation(const Shape& shape) {
65     XlaBuilder builder("Floor");
66     auto param = Parameter(&builder, 0, shape, "param");
67     Floor(param);
68     auto build_status = builder.Build();
69     EXPECT_IS_OK(build_status.status());
70     return std::move(build_status).value();
71   }
72 
CreateR0FloorComputation()73   XlaComputation CreateR0FloorComputation() {
74     return CreateFloorComputation(r0f32_);
75   }
76 
CreateR1FloorComputation()77   XlaComputation CreateR1FloorComputation() {
78     return CreateFloorComputation(r1s2f32_);
79   }
80 
CreateTupleCeilComputation(const std::string & computation_name,const Shape & tuple_shape)81   XlaComputation CreateTupleCeilComputation(const std::string& computation_name,
82                                             const Shape& tuple_shape) {
83     XlaBuilder builder(computation_name);
84     auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
85     auto x = GetTupleElement(tuple, 0);
86     auto y = GetTupleElement(tuple, 1);
87     auto x_ceil = Ceil(x);
88     auto y_ceil = Ceil(y);
89     Tuple(&builder, {x_ceil, y_ceil});
90     auto build_status = builder.Build();
91     EXPECT_IS_OK(build_status.status());
92     return std::move(build_status).value();
93   }
94 
CreateR0TupleCeilComputation()95   XlaComputation CreateR0TupleCeilComputation() {
96     return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_);
97   }
98 
CreateR1TupleCeilComputation()99   XlaComputation CreateR1TupleCeilComputation() {
100     return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_);
101   }
102 
CreateTupleFloorComputation(const std::string & computation_name,const Shape & tuple_shape)103   XlaComputation CreateTupleFloorComputation(
104       const std::string& computation_name, const Shape& tuple_shape) {
105     XlaBuilder builder(computation_name);
106     auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
107     auto x = GetTupleElement(tuple, 0);
108     auto y = GetTupleElement(tuple, 1);
109     auto x_floor = Floor(x);
110     auto y_floor = Floor(y);
111     Tuple(&builder, {x_floor, y_floor});
112     auto build_status = builder.Build();
113     EXPECT_IS_OK(build_status.status());
114     return std::move(build_status).value();
115   }
116 
CreateR0TupleFloorComputation()117   XlaComputation CreateR0TupleFloorComputation() {
118     return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_);
119   }
120 
CreateR1TupleFloorComputation()121   XlaComputation CreateR1TupleFloorComputation() {
122     return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_);
123   }
124 
CreateTupleAddComputation(const std::string & computation_name,const Shape & tuple_shape)125   XlaComputation CreateTupleAddComputation(const std::string& computation_name,
126                                            const Shape& tuple_shape) {
127     XlaBuilder builder(computation_name);
128     auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
129     auto x = GetTupleElement(tuple, 0);
130     auto y = GetTupleElement(tuple, 1);
131     Add(x, y);
132     auto build_status = builder.Build();
133     EXPECT_IS_OK(build_status.status());
134     return std::move(build_status).value();
135   }
136 
CreateR0TupleAddComputation()137   XlaComputation CreateR0TupleAddComputation() {
138     return CreateTupleAddComputation("AddR0", tuple_2_r0f32_);
139   }
140 
CreateR1TupleAddComputation()141   XlaComputation CreateR1TupleAddComputation() {
142     return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_);
143   }
144 
CreateTupleSubComputation(const std::string & computation_name,const Shape & tuple_shape)145   XlaComputation CreateTupleSubComputation(const std::string& computation_name,
146                                            const Shape& tuple_shape) {
147     XlaBuilder builder(computation_name);
148     auto tuple = Parameter(&builder, 0, tuple_shape, "tuple");
149     auto x = GetTupleElement(tuple, 0);
150     auto y = GetTupleElement(tuple, 1);
151     Sub(x, y);
152     auto build_status = builder.Build();
153     EXPECT_IS_OK(build_status.status());
154     return std::move(build_status).value();
155   }
156 
CreateR0TupleSubComputation()157   XlaComputation CreateR0TupleSubComputation() {
158     return CreateTupleSubComputation("SubR0", tuple_2_r0f32_);
159   }
160 
CreateR1TupleSubComputation()161   XlaComputation CreateR1TupleSubComputation() {
162     return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_);
163   }
164 
165   Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
166   Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2});
167   Shape tuple_2_r0f32_ = ShapeUtil::MakeTupleShape(
168       {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})});
169   Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape(
170       {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(F32, {2})});
171   Shape empty_tuple_ = ShapeUtil::MakeTupleShape({});
172   ErrorSpec error_spec_{0.001};
173 };
174 
175 // Test fixture to run indexed conditional (switch/case) tests with varying
176 // number of branches.
177 class CaseOpTest : public ConditionalOpTest,
178                    public ::testing::WithParamInterface<int> {};
179 
180 // Test true and false computations that do not take any parameters.
XLA_TEST_F(ConditionalOpTest,Parameters0)181 XLA_TEST_F(ConditionalOpTest, Parameters0) {
182   XlaBuilder builder(TestName());
183   XlaOp pred;
184   auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
185   auto operands = Tuple(&builder, {});
186   auto true_computation = CreateR0ConstantComputation(56.0f);
187   auto false_computation = CreateR0ConstantComputation(12.0f);
188   Conditional(pred, operands, true_computation, operands, false_computation);
189 
190   ComputeAndCompareR0<float>(&builder, 56.0f, {pred_arg.get()}, error_spec_);
191 }
192 
193 // Test branch computations that do not take any parameters.
XLA_TEST_P(CaseOpTest,Parameters0)194 XLA_TEST_P(CaseOpTest, Parameters0) {
195   int num_branches = GetParam();
196   for (int bi = -1; bi <= num_branches; ++bi) {
197     SCOPED_TRACE(bi);
198     XlaBuilder builder(TestName());
199     XlaOp branch_index;
200     auto branch_index_arg = CreateR0Parameter<int32_t>(
201         bi, 0, "branch_index_arg", &builder, &branch_index);
202     auto operand = Tuple(&builder, {});
203 
204     std::vector<XlaOp> operands(num_branches, operand);
205     std::vector<XlaComputation> branches;
206     branches.reserve(num_branches);
207     std::vector<const XlaComputation*> branches_p(num_branches);
208     for (int i = 0; i < num_branches; ++i) {
209       branches.emplace_back(
210           CreateR0ConstantComputation(static_cast<float>(i) * 10));
211       branches_p[i] = &branches[i];
212     }
213     Conditional(branch_index, branches_p, operands);
214 
215     float expected = 10 * static_cast<float>((bi < 0 || bi >= num_branches)
216                                                  ? num_branches - 1
217                                                  : bi);
218     ComputeAndCompareR0<float>(&builder, expected, {branch_index_arg.get()},
219                                error_spec_);
220   }
221 }
222 
223 // Test true and false computations that take in 1 parameter.
XLA_TEST_F(ConditionalOpTest,Parameters1)224 XLA_TEST_F(ConditionalOpTest, Parameters1) {
225   XlaBuilder builder(TestName());
226   XlaOp pred;
227   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
228   auto operand1 = ConstantR0<float>(&builder, 56.0f);
229   auto operand2 = ConstantR0<float>(&builder, 12.0f);
230   auto identity = CreateR0IdentityComputation();
231   Conditional(pred, operand1, identity, operand2, identity);
232 
233   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
234 }
235 
236 // Test branch computations that take in 1 parameter.
XLA_TEST_P(CaseOpTest,Parameters1)237 XLA_TEST_P(CaseOpTest, Parameters1) {
238   int num_branches = GetParam();
239   for (int bi = -1; bi <= num_branches; ++bi) {
240     SCOPED_TRACE(bi);
241     XlaBuilder builder(TestName());
242     XlaOp branch_index;
243     auto branch_index_arg = CreateR0Parameter<int32_t>(
244         bi, 0, "branch_index_arg", &builder, &branch_index);
245 
246     auto make_branch = [&builder, this](int i) {
247       auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
248       Add(ConstantR0<float>(sb.get(), static_cast<float>(i)),
249           Parameter(sb.get(), 0, r0f32_, "p0"));
250       return sb->BuildAndNoteError();
251     };
252     std::vector<XlaComputation> branches;
253     branches.reserve(num_branches);
254     std::vector<const XlaComputation*> branches_p(num_branches);
255     std::vector<XlaOp> operands;
256     operands.reserve(num_branches);
257     std::vector<float> expecteds(num_branches);
258     for (int i = 0; i < num_branches; ++i) {
259       branches.emplace_back(make_branch(i));
260       branches_p[i] = &branches[i];
261       auto fi = static_cast<float>(i);
262       operands.emplace_back(ConstantR0<float>(&builder, 10 * fi + 7));
263       expecteds[i] = 10 * fi + 7 + fi;
264     }
265 
266     Conditional(branch_index, branches_p, operands);
267     float expected = (bi < 0 || bi >= num_branches)
268                          ? expecteds[num_branches - 1]
269                          : expecteds[bi];
270     ComputeAndCompareR0<float>(&builder, expected, {branch_index_arg.get()},
271                                error_spec_);
272   }
273 }
274 
275 // Test conditional with two different computations in the true and false cases
276 // that take in different arguments.
XLA_TEST_F(ConditionalOpTest,DiffComputationsDiffArgs)277 XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) {
278   XlaBuilder builder(TestName());
279   XlaOp pred;
280   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
281   auto operand1 = ConstantR0<float>(&builder, 56.4f);
282   auto operand2 = ConstantR0<float>(&builder, 12.6f);
283   Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
284               CreateR0FloorComputation());
285 
286   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
287 }
288 
289 // Test conditional with two different computations in the true and false cases
290 // that take in the same arguments.
XLA_TEST_F(ConditionalOpTest,DiffComputationsSameArg)291 XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) {
292   XlaBuilder builder(TestName());
293   XlaOp pred;
294   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
295   auto operand = ConstantR0<float>(&builder, 12.6f);
296   Conditional(pred, operand, CreateR0CeilComputation(), operand,
297               CreateR0FloorComputation());
298 
299   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
300 }
301 
302 // Test conditional with the same computation in the true and false cases but
303 // take in different arguments.
XLA_TEST_F(ConditionalOpTest,SameComputationDiffArgs)304 XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) {
305   XlaBuilder builder(TestName());
306   XlaOp pred;
307   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
308   auto operand1 = ConstantR0<float>(&builder, 56.4f);
309   auto operand2 = ConstantR0<float>(&builder, 12.6f);
310   auto floor = CreateR0FloorComputation();
311   Conditional(pred, operand1, floor, operand2, floor);
312 
313   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
314 }
315 
316 // Test conditional with the same computation in the true and false cases that
317 // take in the same arguments.
XLA_TEST_F(ConditionalOpTest,SameComputationSameArg)318 XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) {
319   XlaBuilder builder(TestName());
320   XlaOp pred;
321   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
322   auto operand = ConstantR0<float>(&builder, 12.6f);
323   auto floor = CreateR0FloorComputation();
324   Conditional(pred, operand, floor, operand, floor);
325 
326   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
327 }
328 
329 // Test conditional with different instances of the same computation in the true
330 // and false cases.
XLA_TEST_F(ConditionalOpTest,SameComputationDiffInstances)331 XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) {
332   XlaBuilder builder(TestName());
333   XlaOp pred;
334   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
335   auto operand1 = ConstantR0<float>(&builder, 56.4f);
336   auto operand2 = ConstantR0<float>(&builder, 12.6f);
337   Conditional(pred, operand1, CreateR0FloorComputation(), operand2,
338               CreateR0FloorComputation());
339 
340   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
341 }
342 
343 // Test the case when a call invokes a computation that contains a conditional.
XLA_TEST_F(ConditionalOpTest,ConditionalWithCall)344 XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) {
345   Shape r0bool = ShapeUtil::MakeShape(PRED, {});
346   XlaBuilder inner_builder(TestName() + ".inner_conditional");
347   auto pred_cond = Parameter(&inner_builder, 0, r0bool, "param0");
348   auto true_operand = Parameter(&inner_builder, 1, r0f32_, "param1");
349   auto false_operand = Parameter(&inner_builder, 2, r0f32_, "param2");
350   Conditional(pred_cond, true_operand, CreateR0CeilComputation(), false_operand,
351               CreateR0FloorComputation());
352   auto inner_builder_result = inner_builder.Build().value();
353 
354   XlaBuilder builder(TestName());
355   XlaOp pred;
356   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
357   auto operand1 = ConstantR0<float>(&builder, 56.4f);
358   auto operand2 = ConstantR0<float>(&builder, 12.6f);
359   Call(&builder, inner_builder_result, {pred, operand1, operand2});
360 
361   ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
362 }
363 
364 // Test true and false computations that take in 2 parameters and predicate is
365 // true.
XLA_TEST_F(ConditionalOpTest,Parameters2TrueBranch)366 XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) {
367   XlaBuilder builder(TestName());
368   XlaOp pred;
369   auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
370   auto operand1 = ConstantR0<float>(&builder, 56.0f);
371   auto operand2 = ConstantR0<float>(&builder, 12.0f);
372   auto operands = Tuple(&builder, {operand1, operand2});
373   Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
374               CreateR0TupleSubComputation());
375 
376   ComputeAndCompareR0<float>(&builder, 68.0f, {pred_arg.get()}, error_spec_);
377 }
378 
379 // Test true and false computations that take in 2 parameters and predicate is
380 // false.
XLA_TEST_F(ConditionalOpTest,Parameters2FalseBranch)381 XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) {
382   XlaBuilder builder(TestName());
383   XlaOp pred;
384   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
385   auto operand1 = ConstantR0<float>(&builder, 56.0f);
386   auto operand2 = ConstantR0<float>(&builder, 12.0f);
387   auto operands = Tuple(&builder, {operand1, operand2});
388   Conditional(pred, operands, CreateR0TupleAddComputation(), operands,
389               CreateR0TupleSubComputation());
390 
391   ComputeAndCompareR0<float>(&builder, 44.0f, {pred_arg.get()}, error_spec_);
392 }
393 
394 // Test true and false computations that take in 2 array parameters and
395 // predicate is true.
XLA_TEST_F(ConditionalOpTest,Parameters2ArrayTrueBranch)396 XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) {
397   XlaBuilder builder(TestName());
398   XlaOp pred;
399   auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
400   auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
401   auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
402   auto operands = Tuple(&builder, {operand1, operand2});
403   Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
404               CreateR1TupleSubComputation());
405 
406   ComputeAndCompareR1<float>(&builder, {34.0f, 67.0f}, {pred_arg.get()},
407                              error_spec_);
408 }
409 
410 // Test branch computations that take in 2 array parameters.
XLA_TEST_P(CaseOpTest,Parameters2Array)411 XLA_TEST_P(CaseOpTest, Parameters2Array) {
412   int num_branches = GetParam();
413   for (int bi = -1; bi <= num_branches; ++bi) {
414     SCOPED_TRACE(bi);
415     XlaBuilder builder(TestName());
416     XlaOp branch_index;
417     auto branch_index_arg =
418         CreateR0Parameter<int32_t>(bi, 0, "pred", &builder, &branch_index);
419     auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
420     auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
421     auto operands = Tuple(&builder, {operand1, operand2});
422     auto make_branch = [&builder, this](int i) {
423       auto sb = builder.CreateSubBuilder(absl::StrCat("branch_", i));
424       auto p = Parameter(sb.get(), 0, tuple_2_r1s2f32_, "p0");
425       Add(Mul(ConstantR0<float>(sb.get(), static_cast<float>(i)),
426               GetTupleElement(p, 0)),
427           GetTupleElement(p, 1));
428       return sb->BuildAndNoteError();
429     };
430     std::vector<XlaComputation> branches;
431     branches.reserve(num_branches);
432     std::vector<const XlaComputation*> branches_p(num_branches);
433     for (int i = 0; i < num_branches; ++i) {
434       branches.emplace_back(make_branch(i));
435       branches_p[i] = &branches[i];
436     }
437     Conditional(branch_index, branches_p,
438                 std::vector<XlaOp>(num_branches, operands));
439     auto modified_bi = static_cast<float>(
440         (bi < 0 || bi >= num_branches) ? num_branches - 1 : bi);
441     ComputeAndCompareR1<float>(
442         &builder, {24.0f * modified_bi + 10, 56.0f * modified_bi + 11},
443         {branch_index_arg.get()}, error_spec_);
444   }
445 }
446 
447 INSTANTIATE_TEST_SUITE_P(CaseOpTest_Instantiation, CaseOpTest,
448                          ::testing::Values(1, 2, 3, 4, 5));
449 
450 // Test true and false computations that take in 2 array parameters and
451 // predicate is false.
XLA_TEST_F(ConditionalOpTest,Parameters2ArrayFalseBranch)452 XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) {
453   XlaBuilder builder(TestName());
454   XlaOp pred;
455   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
456   auto operand1 = ConstantR1<float>(&builder, {24.0f, 56.0f});
457   auto operand2 = ConstantR1<float>(&builder, {10.0f, 11.0f});
458   auto operands = Tuple(&builder, {operand1, operand2});
459   Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
460               CreateR1TupleSubComputation());
461 
462   ComputeAndCompareR1<float>(&builder, {14.0f, 45.0f}, {pred_arg.get()},
463                              error_spec_);
464 }
465 
466 // Test true and false computations that return a tuple of scalars.
XLA_TEST_F(ConditionalOpTest,ReturnTupleOfScalars)467 XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) {
468   XlaBuilder builder(TestName());
469   XlaOp pred;
470   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
471   auto operands = Tuple(&builder, {ConstantR0<float>(&builder, 12.2f),
472                                    ConstantR0<float>(&builder, 25.6f)});
473   Conditional(pred, operands, CreateR0TupleCeilComputation(), operands,
474               CreateR0TupleFloorComputation());
475 
476   ComputeAndCompareTuple(
477       &builder,
478       LiteralUtil::MakeTupleFromSlices({LiteralUtil::CreateR0<float>(12.0f),
479                                         LiteralUtil::CreateR0<float>(25.0f)}),
480       {pred_arg.get()}, error_spec_);
481 }
482 
483 // Test true and false computations that return a tuple of arrays.
XLA_TEST_F(ConditionalOpTest,ReturnTupleOfArrays)484 XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) {
485   XlaBuilder builder(TestName());
486   XlaOp pred;
487   auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
488   auto operands =
489       Tuple(&builder, {ConstantR1<float>(&builder, {12.2f, 15.8f}),
490                        ConstantR1<float>(&builder, {25.6f, 29.2f})});
491   Conditional(pred, operands, CreateR1TupleCeilComputation(), operands,
492               CreateR1TupleFloorComputation());
493 
494   ComputeAndCompareTuple(&builder,
495                          LiteralUtil::MakeTupleFromSlices(
496                              {LiteralUtil::CreateR1<float>({13.0f, 16.0f}),
497                               LiteralUtil::CreateR1<float>({26.0f, 30.0f})}),
498                          {pred_arg.get()}, error_spec_);
499 }
500 
501 // Test true and false computations that return a tuple of a predicate, a
502 // scalar, and an array.
XLA_TEST_F(ConditionalOpTest,ReturnTupleofPredicateScalarArray)503 XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) {
504   XlaBuilder true_builder(TestName() + ".true");
505   {
506     Parameter(&true_builder, 0, empty_tuple_, "tuple");
507     auto true_pred = ConstantR0<bool>(&true_builder, true);
508     auto true_scalar = ConstantR0<float>(&true_builder, 12.2f);
509     auto true_array = ConstantR1<float>(&true_builder, {12.8f, 14.6f});
510     Tuple(&true_builder, {true_pred, true_scalar, true_array});
511   }
512   auto true_builder_result = true_builder.Build();
513   EXPECT_IS_OK(true_builder_result.status());
514 
515   XlaBuilder false_builder(TestName() + ".false");
516   {
517     Parameter(&false_builder, 0, empty_tuple_, "tuple");
518     auto false_pred = ConstantR0<bool>(&false_builder, false);
519     auto false_scalar = ConstantR0<float>(&false_builder, 25.6f);
520     auto false_array = ConstantR1<float>(&false_builder, {26.4f, 32.6f});
521     Tuple(&false_builder, {false_pred, false_scalar, false_array});
522   }
523   auto false_builder_result = false_builder.Build();
524   EXPECT_IS_OK(false_builder_result.status());
525 
526   XlaBuilder builder(TestName());
527   XlaOp pred;
528   auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
529   auto operands = Tuple(&builder, {});
530   Conditional(pred, operands, std::move(true_builder_result).value(), operands,
531               std::move(false_builder_result).value());
532 
533   ComputeAndCompareTuple(&builder,
534                          LiteralUtil::MakeTupleFromSlices(
535                              {LiteralUtil::CreateR0<bool>(true),
536                               LiteralUtil::CreateR0<float>(12.2f),
537                               LiteralUtil::CreateR1<float>({12.8f, 14.6f})}),
538                          {pred_arg.get()}, error_spec_);
539 }
540 
541 // Test true and false computations that return a nested tuple.
XLA_TEST_F(ConditionalOpTest,ReturnNestedTuple)542 XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) {
543   XlaBuilder true_builder(TestName() + ".true");
544   {
545     Parameter(&true_builder, 0, empty_tuple_, "tuple");
546     auto true_constant1 = ConstantR0<float>(&true_builder, 12.2f);
547     auto true_constant2 = ConstantR1<float>(&true_builder, {12.8f, 14.6f});
548     auto true_constant3 = ConstantR1<float>(&true_builder, {25.4f, 29.8f});
549     auto true_constant4 = ConstantR0<float>(&true_builder, 35.6f);
550     Tuple(&true_builder,
551           {Tuple(&true_builder, {true_constant1, true_constant2}),
552            Tuple(&true_builder, {true_constant3, true_constant4})});
553   }
554   auto true_builder_result = true_builder.Build();
555   EXPECT_IS_OK(true_builder_result.status());
556 
557   XlaBuilder false_builder(TestName() + ".false");
558   {
559     Parameter(&false_builder, 0, empty_tuple_, "tuple");
560     auto false_constant1 = ConstantR0<float>(&false_builder, 46.6f);
561     auto false_constant2 = ConstantR1<float>(&false_builder, {54.4f, 58.4f});
562     auto false_constant3 = ConstantR1<float>(&false_builder, {62.1f, 67.4f});
563     auto false_constant4 = ConstantR0<float>(&false_builder, 9.3f);
564     Tuple(&false_builder,
565           {Tuple(&false_builder, {false_constant1, false_constant2}),
566            Tuple(&false_builder, {false_constant3, false_constant4})});
567   }
568   auto false_builder_result = false_builder.Build();
569   EXPECT_IS_OK(false_builder_result.status());
570 
571   XlaBuilder builder(TestName());
572   XlaOp pred;
573   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
574   auto operands = Tuple(&builder, {});
575   Conditional(pred, operands, std::move(true_builder_result).value(), operands,
576               std::move(false_builder_result).value());
577 
578   ComputeAndCompareTuple(
579       &builder,
580       LiteralUtil::MakeTupleFromSlices(
581           {LiteralUtil::MakeTupleFromSlices(
582                {LiteralUtil::CreateR0<float>(46.6f),
583                 LiteralUtil::CreateR1<float>({54.4f, 58.4f})}),
584            LiteralUtil::MakeTupleFromSlices(
585                {LiteralUtil::CreateR1<float>({62.1f, 67.4f}),
586                 LiteralUtil::CreateR0<float>(9.3f)})}),
587       {pred_arg.get()}, error_spec_);
588 }
589 
590 // Test conditional that takes in scalar operands in the form of external
591 // params.
XLA_TEST_F(ConditionalOpTest,ScalarOperandsFromExternalParams)592 XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) {
593   Shape r0bool = ShapeUtil::MakeShape(PRED, {});
594   XlaBuilder builder(TestName());
595 
596   XlaOp pred, operand1, operand2;
597   auto pred_arg = CreateR0Parameter<bool>(true, 0, "pred", &builder, &pred);
598   auto operand1_param =
599       CreateR0Parameter<float>(56.3f, 1, "operand1", &builder, &operand1);
600   auto operand2_param =
601       CreateR0Parameter<float>(12.7f, 2, "operand2", &builder, &operand2);
602   Conditional(pred, operand1, CreateR0CeilComputation(), operand2,
603               CreateR0FloorComputation());
604 
605   ComputeAndCompareR0<float>(
606       &builder, 57.0f,
607       {pred_arg.get(), operand1_param.get(), operand2_param.get()},
608       error_spec_);
609 }
610 
611 // Test conditional that takes in array operands in the form of external params.
XLA_TEST_F(ConditionalOpTest,ArrayOperandsFromExternalParams)612 XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) {
613   Shape r0bool = ShapeUtil::MakeShape(PRED, {});
614   XlaBuilder builder(TestName());
615 
616   XlaOp pred, operand1, operand2;
617   auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
618   auto operand1_param = CreateR1Parameter<float>({24.3f, 56.7f}, 1, "operand1",
619                                                  &builder, &operand1);
620   auto operand2_param = CreateR1Parameter<float>({10.2f, 11.6f}, 2, "operand2",
621                                                  &builder, &operand2);
622   Conditional(pred, operand1, CreateR1CeilComputation(), operand2,
623               CreateR1FloorComputation());
624 
625   ComputeAndCompareR1<float>(
626       &builder, {10.0f, 11.0f},
627       {pred_arg.get(), operand1_param.get(), operand2_param.get()},
628       error_spec_);
629 }
630 
631 // Test the case where one conditional is nested within another.
XLA_TEST_F(ConditionalOpTest,NestedConditionals)632 XLA_TEST_F(ConditionalOpTest, NestedConditionals) {
633   XlaBuilder inner_builder(TestName() + ".inner_conditional");
634   {
635     Shape r0bool = ShapeUtil::MakeShape(PRED, {});
636     Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
637     auto param0 = Parameter(&inner_builder, 0, tuple_shape, "param0");
638     auto pred_cond = GetTupleElement(param0, 0);
639     auto true_operand = GetTupleElement(param0, 1);
640     auto false_operand = GetTupleElement(param0, 2);
641     Conditional(pred_cond, true_operand, CreateR0CeilComputation(),
642                 false_operand, CreateR0FloorComputation());
643   }
644   auto inner_builder_result = inner_builder.Build();
645   EXPECT_IS_OK(inner_builder_result.status());
646 
647   XlaBuilder builder(TestName());
648   XlaOp pred1, pred2;
649   auto pred1_arg = CreateR0Parameter<bool>(true, 0, "pred1", &builder, &pred1);
650   auto pred2_arg = CreateR0Parameter<bool>(false, 1, "pred2", &builder, &pred2);
651   auto operand1 = ConstantR0<float>(&builder, 1.1f);
652   auto operand2 = ConstantR0<float>(&builder, 12.2f);
653   auto operand3 = ConstantR0<float>(&builder, 43.3f);
654   auto tuple_operand = Tuple(&builder, {pred2, operand1, operand2});
655   Conditional(pred1, tuple_operand, std::move(inner_builder_result).value(),
656               operand3, CreateR0IdentityComputation());
657 
658   ComputeAndCompareR0<float>(&builder, 12.0f,
659                              {pred1_arg.get(), pred2_arg.get()}, error_spec_);
660 }
661 
// Test a conditional inside a computation invoked via Call. The called
// computation reads (pred, true operand, false operand) from one tuple
// parameter. With pred = false, floor(12.2) = 12 is expected.
XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) {
  XlaBuilder inner_builder(TestName() + ".inner_conditional");
  {
    Shape r0bool = ShapeUtil::MakeShape(PRED, {});
    Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_});
    auto param0 = Parameter(&inner_builder, 0, tuple_shape, "param0");
    auto pred_cond = GetTupleElement(param0, 0);
    auto true_operand = GetTupleElement(param0, 1);
    auto false_operand = GetTupleElement(param0, 2);
    Conditional(pred_cond, true_operand, CreateR0CeilComputation(),
                false_operand, CreateR0FloorComputation());
  }
  auto inner_builder_result = inner_builder.Build();
  // ASSERT rather than EXPECT: value() below would otherwise abort the whole
  // test binary on a failed build.
  ASSERT_IS_OK(inner_builder_result.status());

  XlaBuilder builder(TestName());
  XlaOp pred;
  auto pred_arg = CreateR0Parameter<bool>(false, 0, "pred", &builder, &pred);
  auto operand1 = ConstantR0<float>(&builder, 1.1f);
  auto operand2 = ConstantR0<float>(&builder, 12.2f);
  auto tuple_operand = Tuple(&builder, {pred, operand1, operand2});
  Call(&builder, std::move(inner_builder_result).value(), {tuple_operand});

  ComputeAndCompareR0<float>(&builder, 12.0f, {pred_arg.get()}, error_spec_);
}
687 
688 // Test a mismatch in the shape of the true operand and true computation.
XLA_TEST_F(ConditionalOpTest,ShapeMismatch)689 XLA_TEST_F(ConditionalOpTest, ShapeMismatch) {
690   XlaBuilder builder(TestName());
691   auto pred = ConstantR0<bool>(&builder, true);
692   auto operand1 = ConstantR0<float>(&builder, 56.0f);
693   auto operand2 = ConstantR0<float>(&builder, 12.0f);
694   auto operands = Tuple(&builder, {operand1, operand2});
695   Conditional(pred, operands, CreateR1TupleAddComputation(), operands,
696               CreateR0TupleSubComputation());
697 
698   auto result = builder.Build();
699   EXPECT_FALSE(result.ok());
700   EXPECT_THAT(result.status().error_message(),
701               ::testing::HasSubstr("operand 0 must match the shape of the "
702                                    "only parameter of branch computation 0"));
703 }
704 
// Chains two conditionals whose predicates both compare the ORIGINAL inputs
// x and y (not the first conditional's result):
//   first  = (x <  y) ? forward(input) : swap(input)
//   result = (x >= y) ? swap(first)    : forward(first)
// Exactly one predicate holds, so the pair is either forwarded twice or
// swapped twice. Either way the result must equal the input pair.
XLA_TEST_F(ConditionalOpTest, SwappedInputsInSequentialConditionals) {
  const Shape pair_shape = ShapeUtil::MakeTupleShape({r0f32_, r0f32_});

  // Builds a computation mapping a pair (first, second) to (second, first)
  // when `swap` is set, and to (first, second) otherwise.
  auto build_pair_permutation = [&](const char* name_suffix,
                                    const char* param_name,
                                    bool swap) -> XlaComputation {
    XlaBuilder perm_builder(TestName() + name_suffix);
    XlaOp pair = Parameter(&perm_builder, 0, pair_shape, param_name);
    XlaOp first = GetTupleElement(pair, 0);
    XlaOp second = GetTupleElement(pair, 1);
    if (swap) {
      Tuple(&perm_builder, {second, first});
    } else {
      Tuple(&perm_builder, {first, second});
    }
    return perm_builder.Build().value();
  };

  const XlaComputation swapper =
      build_pair_permutation(".swapper", "sp0", /*swap=*/true);
  const XlaComputation forwarder =
      build_pair_permutation(".forwarder", "fp0", /*swap=*/false);

  const XlaComputation sequential = [&]() -> XlaComputation {
    XlaBuilder seq_builder(TestName() + ".main");
    XlaOp input = Parameter(&seq_builder, 0, pair_shape, "mp0");
    XlaOp x = GetTupleElement(input, 0);
    XlaOp y = GetTupleElement(input, 1);
    XlaOp first_pass =
        Conditional(Lt(x, y), input, forwarder, input, swapper);
    Conditional(Ge(x, y), first_pass, swapper, first_pass, forwarder);
    return seq_builder.Build().value();
  }();

  // Feeds (a, b) through `sequential` and expects it back unchanged.
  auto expect_round_trip = [&](float a, float b) {
    XlaBuilder builder(TestName());
    XlaOp x;
    XlaOp y;
    auto x_arg = CreateR0Parameter<float>(a, 0, "x", &builder, &x);
    auto y_arg = CreateR0Parameter<float>(b, 1, "y", &builder, &y);
    Call(&builder, sequential, {Tuple(&builder, {x, y})});

    Literal expected = LiteralUtil::MakeTupleFromSlices(
        {LiteralUtil::CreateR0<float>(a), LiteralUtil::CreateR0<float>(b)});
    ComputeAndCompareTuple(&builder, expected, {x_arg.get(), y_arg.get()},
                           error_spec_);
  };

  expect_round_trip(3.11f, 9.4f);    // x < y: forwarded twice.
  expect_round_trip(11.24f, 5.55f);  // x >= y: swapped twice.
}
756 
757 // Test conditional that duplicates tuple elements in the then and else
758 // computations. This is a regression test for b/112550242.
XLA_TEST_F(ConditionalOpTest,DuplicateElementsConditional)759 XLA_TEST_F(ConditionalOpTest, DuplicateElementsConditional) {
760   const Shape scalar = ShapeUtil::MakeShape(S32, {});
761   const Shape tuple2 = ShapeUtil::MakeTupleShape({scalar, scalar});
762   XlaComputation then_comp;
763   {
764     XlaBuilder builder(TestName() + ".then");
765     auto p = Parameter(&builder, 0, tuple2, "then.p");
766     auto e0 = GetTupleElement(p, 0);
767     auto e1 = GetTupleElement(p, 1);
768     Tuple(&builder, {e0, e1, e0});
769     then_comp = builder.Build().value();
770   }
771   XlaComputation else_comp;
772   {
773     XlaBuilder builder(TestName() + ".else");
774     auto p = Parameter(&builder, 0, tuple2, "else.p");
775     auto e0 = GetTupleElement(p, 0);
776     auto e1 = GetTupleElement(p, 1);
777     Tuple(&builder, {e0, e1, e1});
778     else_comp = builder.Build().value();
779   }
780 
781   {
782     // Pred is true case.
783     std::vector<Literal> args;
784     args.push_back(LiteralUtil::MakeTupleFromSlices(
785         {LiteralUtil::CreateR0<int32_t>(123),
786          LiteralUtil::CreateR0<int32_t>(-42)}));
787     args.push_back(LiteralUtil::CreateR0<bool>(true));
788     XlaBuilder builder(TestName() + ".main");
789     auto p = Parameter(&builder, 0, tuple2, "p0");
790     auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
791     Conditional(p_pred, p, then_comp, p, else_comp);
792     ComputeAndCompare(&builder, args);
793   }
794   {
795     // Pred is false case.
796     std::vector<Literal> args;
797     args.push_back(LiteralUtil::MakeTupleFromSlices(
798         {LiteralUtil::CreateR0<int32_t>(123),
799          LiteralUtil::CreateR0<int32_t>(-42)}));
800     args.push_back(LiteralUtil::CreateR0<bool>(false));
801     XlaBuilder builder(TestName() + ".main");
802     auto p = Parameter(&builder, 0, tuple2, "p0");
803     auto p_pred = Parameter(&builder, 1, ShapeUtil::MakeShape(PRED, {}), "p1");
804     Conditional(p_pred, p, then_comp, p, else_comp);
805     ComputeAndCompare(&builder, args);
806   }
807 }
808 
809 }  // namespace
810 }  // namespace xla
811