1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <vector>
18
19 #include "tensorflow/compiler/xla/client/global_data.h"
20 #include "tensorflow/compiler/xla/client/local_client.h"
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
23 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
24 #include "tensorflow/compiler/xla/tests/test_macros.h"
25 #include "tensorflow/core/platform/test.h"
26
27 namespace xla {
28 namespace {
29
30 class SelectTest : public ClientLibraryTestBase {
31 public:
32 ErrorSpec error_spec_{0.0001};
33 };
34
TEST_F(SelectTest, SelectScalarF32True) {
  // A scalar `true` predicate must yield the on_true scalar.
  XlaBuilder b(TestName());
  XlaOp predicate = ConstantR0<bool>(&b, true);
  XlaOp if_true = ConstantR0<float>(&b, 123.0f);
  XlaOp if_false = ConstantR0<float>(&b, 42.0f);
  Select(predicate, if_true, if_false);

  const float expected = 123.0f;
  ComputeAndCompareR0<float>(&b, expected, {}, error_spec_);
}
44
TEST_F(SelectTest, SelectScalarS32True) {
  // Integer flavour of the scalar case: `true` picks the first operand (-42).
  XlaBuilder b(TestName());
  XlaOp predicate = ConstantR0<bool>(&b, true);
  XlaOp if_true = ConstantR0<int32_t>(&b, -42);
  XlaOp if_false = ConstantR0<int32_t>(&b, 42);
  Select(predicate, if_true, if_false);

  const int32_t expected = -42;
  ComputeAndCompareR0<int32_t>(&b, expected, {});
}
54
TEST_F(SelectTest, SelectScalarF32False) {
  // A scalar `false` predicate must yield the on_false scalar.
  XlaBuilder b(TestName());
  XlaOp predicate = ConstantR0<bool>(&b, false);
  XlaOp if_true = ConstantR0<float>(&b, 123.0f);
  XlaOp if_false = ConstantR0<float>(&b, 42.0f);
  Select(predicate, if_true, if_false);

  const float expected = 42.0f;
  ComputeAndCompareR0<float>(&b, expected, {}, error_spec_);
}
64
XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) {
  // Zero-element operands everywhere: the expected output is an empty F32
  // vector.
  XlaBuilder b(TestName());
  XlaOp empty_mask = ConstantR1<bool>(&b, {});
  XlaOp empty_lhs = ConstantR1<float>(&b, {});
  XlaOp empty_rhs = ConstantR1<float>(&b, {});
  Select(empty_mask, empty_lhs, empty_rhs);

  ComputeAndCompareR1<float>(&b, /*expected=*/{}, /*arguments=*/{},
                             error_spec_);
}
74
TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) {
  // Element-wise select driven by a literal mask: positions 1 and 3 come from
  // `lhs`, every other position from `rhs`.
  XlaBuilder b(TestName());
  XlaOp mask = ConstantR1<bool>(&b, {false, true, false, true, false});
  XlaOp lhs = ConstantR1<float>(&b, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  XlaOp rhs = ConstantR1<float>(&b, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Select(mask, lhs, rhs);

  ComputeAndCompareR1<float>(
      &b, /*expected=*/{10.0f, 25.5f, 1.0f, -10.0f, -6.0f},
      /*arguments=*/{}, error_spec_);
}
87
XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) {
  // Empty-vector select, but the mask is computed (Eq of two empty S32
  // vectors) rather than supplied as a constant, unlike
  // SelectR1S0F32WithConstantR1S0PRED.
  XlaBuilder b(TestName());
  XlaOp lhs_ints = ConstantR1<int32_t>(&b, {});
  XlaOp rhs_ints = ConstantR1<int32_t>(&b, {});
  XlaOp mask = Eq(lhs_ints, rhs_ints);
  XlaOp when_true = ConstantR1<float>(&b, {});
  XlaOp when_false = ConstantR1<float>(&b, {});
  Select(mask, when_true, when_false);

  ComputeAndCompareR1<float>(&b, /*expected=*/{}, /*arguments=*/{},
                             error_spec_);
}
101
TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) {
  // Same expected output as SelectR1F32WithConstantR1PRED, but the mask is
  // produced by Eq over two S32 vectors that match only at positions 1 and 3.
  XlaBuilder b(TestName());
  XlaOp lhs_ints = ConstantR1<int32_t>(&b, {1, 2, 3, 4, 5});
  XlaOp rhs_ints = ConstantR1<int32_t>(&b, {9, 2, 9, 4, 9});
  XlaOp mask = Eq(lhs_ints, rhs_ints);
  XlaOp when_true = ConstantR1<float>(&b, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  XlaOp when_false = ConstantR1<float>(&b, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Select(mask, when_true, when_false);

  ComputeAndCompareR1<float>(
      &b, /*expected=*/{10.0f, 25.5f, 1.0f, -10.0f, -6.0f},
      /*arguments=*/{}, error_spec_);
}
118
TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) {
  // Like SelectR1F32WithCmpR1S32s, but the mask comes from a greater-than
  // comparison of two F32 vectors (true at positions 0, 1 and 4).
  XlaBuilder b(TestName());
  XlaOp lhs_vals = ConstantR1<float>(&b, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
  XlaOp rhs_vals = ConstantR1<float>(&b, {-1.0f, -2.0f, 13.0f, 14.0f, 4.4f});
  XlaOp mask = Gt(lhs_vals, rhs_vals);
  XlaOp when_true = ConstantR1<float>(&b, {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
  XlaOp when_false = ConstantR1<float>(&b, {10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
  Select(mask, when_true, when_false);

  ComputeAndCompareR1<float>(
      &b, /*expected=*/{-2.5f, 25.5f, 1.0f, 10.0f, 6.0f},
      /*arguments=*/{}, error_spec_);
}
134
TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) {
  // Both operands arrive as runtime parameters. Select(Gt(v1, v2), v1, v2)
  // keeps, per position, whichever of the two values compares greater.
  XlaBuilder b(TestName());

  XlaOp lhs;
  XlaOp rhs;
  std::unique_ptr<GlobalData> lhs_data = CreateR1Parameter<float>(
      {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1",
      /*builder=*/&b, /*data_handle=*/&lhs);
  std::unique_ptr<GlobalData> rhs_data = CreateR1Parameter<float>(
      {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
      /*builder=*/&b, /*data_handle=*/&rhs);

  XlaOp lhs_wins = Gt(lhs, rhs);
  Select(lhs_wins, lhs, rhs);

  ComputeAndCompareR1<float>(
      &b, /*expected=*/{41.0f, 22.0f, 23.0f, 84.0f},
      /*arguments=*/{lhs_data.get(), rhs_data.get()}, error_spec_);
}
154
TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) {
  // Similar to SelectR1F32WithCmpR1F32sFromParamsSmall, except that the
  // data size passed in and out is large.
  XlaBuilder builder(TestName());

  // Number of floats in the data passed into and out of the computation.
  constexpr int datalen = 15 * 1000;

  // The inputs follow a pattern where v1[i] > v2[i] for the first third of the
  // data and v1[i] < v2[i] everywhere else, so both arms of the Select are
  // exercised. `larger` is strictly greater than `smaller` for every i,
  // including i == 0. Otherwise v1[0] == v2[0] and that element's result
  // would not depend on the comparison at all. All values stay far below 2^24,
  // so they are exactly representable as float.
  std::vector<float> v1vec;
  std::vector<float> v2vec;
  std::vector<float> expected_vec;
  v1vec.reserve(datalen);
  v2vec.reserve(datalen);
  expected_vec.reserve(datalen);
  for (int i = 0; i < datalen; ++i) {
    const float smaller = static_cast<float>(i);
    const float larger = static_cast<float>(2 * i + 1);
    if (i < datalen / 3) {
      v1vec.push_back(larger);
      v2vec.push_back(smaller);
    } else {
      v1vec.push_back(smaller);
      v2vec.push_back(larger);
    }
    // Select(Gt(v1, v2), v1, v2) always keeps the larger of the pair.
    expected_vec.push_back(larger);
  }

  XlaOp v1, v2;
  std::unique_ptr<GlobalData> param0_data =
      CreateR1Parameter<float>(v1vec, /*parameter_number=*/0, /*name=*/"v1",
                               /*builder=*/&builder, /*data_handle=*/&v1);
  std::unique_ptr<GlobalData> param1_data =
      CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
                               /*builder=*/&builder, /*data_handle=*/&v2);

  auto cmp = Gt(v1, v2);
  Select(cmp, v1, v2);
  ComputeAndCompareR1<float>(&builder, expected_vec,
                             {param0_data.get(), param1_data.get()},
                             error_spec_);
}
198
TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) {
  // The mask comes from comparing an R1 S32 vector against an R0 S32 zero:
  // positive entries (positions 0 and 2) select from `when_true`.
  XlaBuilder b(TestName());
  XlaOp values = ConstantR1<int32_t>(&b, {1, -1, 2, -2});
  XlaOp zero = ConstantR0<int32_t>(&b, 0);
  XlaOp is_positive = Gt(values, zero);

  XlaOp when_true = ConstantR1<float>(&b, {11.0f, 22.0f, 33.0f, 44.0f});
  XlaOp when_false =
      ConstantR1<float>(&b, {-111.0f, -222.0f, -333.0f, -444.0f});
  Select(is_positive, when_true, when_false);

  ComputeAndCompareR1<float>(&b,
                             /*expected=*/{11.0f, -222.0f, 33.0f, -444.0f},
                             /*arguments=*/{}, error_spec_);
}
215
TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) {
  // The mask comes from comparing an R1 F32 vector against the R0 threshold
  // 2.5; only the last two entries exceed it.
  XlaBuilder b(TestName());
  XlaOp values = ConstantR1<float>(&b, {1.0f, 2.0f, 3.0f, 4.0f});
  XlaOp threshold = ConstantR0<float>(&b, 2.5f);
  XlaOp above = Gt(values, threshold);

  XlaOp when_true = ConstantR1<float>(&b, {11.0f, 22.0f, 33.0f, 44.0f});
  XlaOp when_false =
      ConstantR1<float>(&b, {-111.0f, -222.0f, -333.0f, -444.0f});
  Select(above, when_true, when_false);

  ComputeAndCompareR1<float>(&b,
                             /*expected=*/{-111.0f, -222.0f, 33.0f, 44.0f},
                             /*arguments=*/{}, error_spec_);
}
232
XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) {
  // Empty vector operands with a scalar predicate. Both predicate values are
  // tried, first false and then true, each in a freshly built computation.
  for (int round = 0; round < 2; ++round) {
    const bool which = (round != 0);
    XlaBuilder b(TestName());
    XlaOp predicate = ConstantR0<bool>(&b, which);
    XlaOp empty_lhs = ConstantR1<float>(&b, {});
    XlaOp empty_rhs = ConstantR1<float>(&b, {});
    Select(predicate, empty_lhs, empty_rhs);

    ComputeAndCompareR1<float>(&b, /*expected=*/{}, /*arguments=*/{},
                               error_spec_);
  }
}
244
TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) {
  // A scalar `true` predicate selects the whole on_true vector.
  XlaBuilder b(TestName());
  XlaOp predicate = ConstantR0<bool>(&b, true);
  XlaOp if_true = ConstantR1<float>(&b, {-2.5f, 25.5f});
  XlaOp if_false = ConstantR1<float>(&b, {10.0f, 5.0f});
  Select(predicate, if_true, if_false);

  ComputeAndCompareR1<float>(&b, /*expected=*/{-2.5f, 25.5f},
                             /*arguments=*/{}, error_spec_);
}
254
TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) {
  // A scalar `false` predicate selects the whole on_false vector.
  XlaBuilder b(TestName());
  XlaOp predicate = ConstantR0<bool>(&b, false);
  XlaOp if_true = ConstantR1<float>(&b, {-2.5f, 25.5f});
  XlaOp if_false = ConstantR1<float>(&b, {10.0f, 5.0f});
  Select(predicate, if_true, if_false);

  ComputeAndCompareR1<float>(&b, /*expected=*/{10.0f, 5.0f},
                             /*arguments=*/{}, error_spec_);
}
264 } // namespace
265 } // namespace xla
266