xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/select_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include <memory>
17 #include <vector>
18 
19 #include "tensorflow/compiler/xla/client/global_data.h"
20 #include "tensorflow/compiler/xla/client/local_client.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
23 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
24 #include "tensorflow/compiler/xla/tests/test_macros.h"
25 #include "tensorflow/core/platform/test.h"
26 
27 namespace xla {
28 namespace {
29 
30 class SelectTest : public ClientLibraryTestBase {
31  public:
32   ErrorSpec error_spec_{0.0001};
33 };
34 
// Scalar F32 select with a constant `true` predicate: the expected result is
// the on_true operand (123.0f).
TEST_F(SelectTest, SelectScalarF32True) {
  XlaBuilder builder(TestName());
  auto pred = ConstantR0<bool>(&builder, true);
  auto on_true = ConstantR0<float>(&builder, 123.0f);
  auto on_false = ConstantR0<float>(&builder, 42.0f);
  Select(pred, on_true, on_false);

  ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
}
44 
// Scalar S32 select with a constant `true` predicate: the expected result is
// the on_true operand (-42). No error spec is passed for the integer compare.
TEST_F(SelectTest, SelectScalarS32True) {
  XlaBuilder builder(TestName());
  auto pred = ConstantR0<bool>(&builder, true);
  auto on_true = ConstantR0<int32_t>(&builder, -42);
  auto on_false = ConstantR0<int32_t>(&builder, 42);
  Select(pred, on_true, on_false);

  ComputeAndCompareR0<int32_t>(&builder, -42, {});
}
54 
// Scalar F32 select with a constant `false` predicate: the expected result is
// the on_false operand (42.0f).
TEST_F(SelectTest, SelectScalarF32False) {
  XlaBuilder builder(TestName());
  auto pred = ConstantR0<bool>(&builder, false);
  auto on_true = ConstantR0<float>(&builder, 123.0f);
  auto on_false = ConstantR0<float>(&builder, 42.0f);
  Select(pred, on_true, on_false);

  ComputeAndCompareR0<float>(&builder, 42.0f, {}, error_spec_);
}
64 
// Select over zero-element R1 operands with a zero-element constant predicate
// vector: the expected result is an empty F32 vector.
XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) {
  XlaBuilder builder(TestName());
  auto pred = ConstantR1<bool>(&builder, {});
  auto on_true = ConstantR1<float>(&builder, {});
  auto on_false = ConstantR1<float>(&builder, {});
  Select(pred, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
74 
// Element-wise select between two five-element F32 vectors driven by a
// constant predicate vector {F, T, F, T, F}: each expected element comes from
// on_true where the predicate is true and from on_false otherwise.
TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) {
  XlaBuilder builder(TestName());
  auto pred = ConstantR1<bool>(&builder, {false, true, false, true, false});
  auto on_true =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  auto on_false =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Select(pred, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
                             error_spec_);
}
87 
// Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector
// is not a constant, but rather the result of an Eq comparison between two
// zero-element S32 vectors. The expected result is an empty F32 vector.
XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) {
  XlaBuilder builder(TestName());
  auto v1 = ConstantR1<int32_t>(&builder, {});
  auto v2 = ConstantR1<int32_t>(&builder, {});
  auto cmp = Eq(v1, v2);
  auto on_true = ConstantR1<float>(&builder, {});
  auto on_false = ConstantR1<float>(&builder, {});
  Select(cmp, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
}
101 
// Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is
// not a constant, but rather the result of an Eq comparison between two S32
// vectors. The inputs are equal at indices 1 and 3 only, so those positions
// are expected to come from on_true and the rest from on_false.
TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) {
  XlaBuilder builder(TestName());
  auto v1 = ConstantR1<int32_t>(&builder, {1, 2, 3, 4, 5});
  auto v2 = ConstantR1<int32_t>(&builder, {9, 2, 9, 4, 9});
  auto cmp = Eq(v1, v2);
  auto on_true =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  auto on_false =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Select(cmp, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
                             error_spec_);
}
118 
// Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s.
// v1 > v2 holds at indices 0, 1 and 4, so those positions are expected to
// come from on_true and indices 2 and 3 from on_false.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) {
  XlaBuilder builder(TestName());
  auto v1 = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
  auto v2 = ConstantR1<float>(&builder, {-1.0f, -2.0f, 13.0f, 14.0f, 4.4f});
  auto cmp = Gt(v1, v2);
  auto on_true =
      ConstantR1<float>(&builder, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  auto on_false =
      ConstantR1<float>(&builder, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Select(cmp, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {},
                             error_spec_);
}
134 
// Selects among two R1F32s, which come from parameters. v1 and v2 are
// compared, and selection between them happens based on a gt-comparison mask,
// so the expected output is the element-wise larger of the two inputs.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) {
  XlaBuilder builder(TestName());

  XlaOp v1, v2;
  std::unique_ptr<GlobalData> param0_data = CreateR1Parameter<float>(
      {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1",
      /*builder=*/&builder, /*data_handle=*/&v1);
  std::unique_ptr<GlobalData> param1_data = CreateR1Parameter<float>(
      {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
      /*builder=*/&builder, /*data_handle=*/&v2);

  auto cmp = Gt(v1, v2);
  Select(cmp, v1, v2);
  ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f},
                             {param0_data.get(), param1_data.get()},
                             error_spec_);
}
154 
// Similar to SelectR1F32WithCmpR1F32sFromParamsSmall, except that the
// data size passed in and out is large.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) {
  XlaBuilder builder(TestName());

  // Number of floats in the data passed into and out of the computation.
  constexpr int datalen = 15 * 1000;

  // The inputs are initialized with a special pattern where in the first third
  // of the data v1[i] >= v2[i] (strictly greater except at i == 0, where both
  // are 0) and elsewhere it's vice versa. The expected output is always the
  // larger of the two values.
  std::vector<float> v1vec;
  std::vector<float> v2vec;
  std::vector<float> expected_vec;
  v1vec.reserve(datalen);
  v2vec.reserve(datalen);
  expected_vec.reserve(datalen);
  for (int i = 0; i < datalen; ++i) {
    // Both values stay below 30000, so the int -> float conversion is exact.
    const float smaller = static_cast<float>(i);
    const float larger = static_cast<float>(i * 2);
    if (i < datalen / 3) {
      v1vec.push_back(larger);
      v2vec.push_back(smaller);
    } else {
      v1vec.push_back(smaller);
      v2vec.push_back(larger);
    }
    expected_vec.push_back(larger);
  }

  XlaOp v1, v2;
  std::unique_ptr<GlobalData> param0_data =
      CreateR1Parameter<float>(v1vec, /*parameter_number=*/0, /*name=*/"v1",
                               /*builder=*/&builder, /*data_handle=*/&v1);
  std::unique_ptr<GlobalData> param1_data =
      CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
                               /*builder=*/&builder, /*data_handle=*/&v2);

  auto cmp = Gt(v1, v2);
  Select(cmp, v1, v2);
  ComputeAndCompareR1<float>(&builder, expected_vec,
                             {param0_data.get(), param1_data.get()},
                             error_spec_);
}
198 
// "gt"-compares a R1S32 with a S32 scalar, and uses the resulting R1PRED to
// select between two R1F32s. Elements 0 and 2 of v are greater than 0, so
// those positions are expected to come from on_true.
TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) {
  XlaBuilder builder(TestName());
  auto v = ConstantR1<int32_t>(&builder, {1, -1, 2, -2});
  auto s = ConstantR0<int32_t>(&builder, 0);
  auto cmp = Gt(v, s);

  auto on_true = ConstantR1<float>(&builder, {11.0f, 22.0f, 33.0f, 44.0f});
  auto on_false =
      ConstantR1<float>(&builder, {-111.0f, -222.0f, -333.0f, -444.0f});
  Select(cmp, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {},
                             error_spec_);
}
215 
// "gt"-compares a R1F32 with a F32 scalar, and uses the resulting R1PRED to
// select between two R1F32s. Elements 2 and 3 of v are greater than 2.5f, so
// those positions are expected to come from on_true.
TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) {
  XlaBuilder builder(TestName());
  auto v = ConstantR1<float>(&builder, {1.0f, 2.0f, 3.0f, 4.0f});
  auto s = ConstantR0<float>(&builder, 2.5f);
  auto cmp = Gt(v, s);

  auto on_true = ConstantR1<float>(&builder, {11.0f, 22.0f, 33.0f, 44.0f});
  auto on_false =
      ConstantR1<float>(&builder, {-111.0f, -222.0f, -333.0f, -444.0f});
  Select(cmp, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {},
                             error_spec_);
}
232 
// Select between two zero-element R1F32 operands using a scalar predicate.
// Runs once for each predicate value; the expected result is an empty vector
// in both cases.
XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) {
  for (bool which : {false, true}) {
    // A fresh builder per iteration keeps the two computations independent.
    XlaBuilder builder(TestName());
    auto pred = ConstantR0<bool>(&builder, which);
    auto on_true = ConstantR1<float>(&builder, {});
    auto on_false = ConstantR1<float>(&builder, {});
    Select(pred, on_true, on_false);

    ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
  }
}
244 
// Select between two R1F32 operands using a scalar `true` predicate: the
// expected result is the whole on_true vector.
TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) {
  XlaBuilder builder(TestName());
  auto pred = ConstantR0<bool>(&builder, true);
  auto on_true = ConstantR1<float>(&builder, {-2.5f, 25.5f});
  auto on_false = ConstantR1<float>(&builder, {10.0f, 5.0f});
  Select(pred, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f}, {}, error_spec_);
}
254 
// Select between two R1F32 operands using a scalar `false` predicate: the
// expected result is the whole on_false vector.
TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) {
  XlaBuilder builder(TestName());
  auto pred = ConstantR0<bool>(&builder, false);
  auto on_true = ConstantR1<float>(&builder, {-2.5f, 25.5f});
  auto on_false = ConstantR1<float>(&builder, {10.0f, 5.0f});
  Select(pred, on_true, on_false);

  ComputeAndCompareR1<float>(&builder, {10.0f, 5.0f}, {}, error_spec_);
}
264 }  // namespace
265 }  // namespace xla
266