1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <numeric>
18 #include <vector>
19
20 #include "tensorflow/compiler/xla/array2d.h"
21 #include "tensorflow/compiler/xla/array4d.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/literal_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
30 #include "tensorflow/compiler/xla/tests/test_macros.h"
31
32 namespace xla {
33 namespace {
34
35 class BroadcastSimpleTest : public ClientLibraryTestBase {
36 public:
BuildBinOp(HloOpcode op,const XlaOp lhs,const XlaOp rhs,XlaBuilder * builder)37 XlaOp BuildBinOp(HloOpcode op, const XlaOp lhs, const XlaOp rhs,
38 XlaBuilder* builder) {
39 switch (op) {
40 case HloOpcode::kMinimum: {
41 return Min(lhs, rhs);
42 }
43 case HloOpcode::kMaximum: {
44 return Max(lhs, rhs);
45 }
46 case HloOpcode::kMultiply: {
47 return Mul(lhs, rhs);
48 }
49 default: {
50 // Default to Add
51 return Add(lhs, rhs);
52 }
53 }
54 }
55
MakeR3Data(absl::Span<const int64_t> bounds,absl::Span<const int64_t> minor_to_major,Shape * r3_shape,Array3D<float> * r3_array,float start,float end,int seed)56 std::unique_ptr<GlobalData> MakeR3Data(
57 absl::Span<const int64_t> bounds,
58 absl::Span<const int64_t> minor_to_major, Shape* r3_shape,
59 Array3D<float>* r3_array, float start, float end, int seed) {
60 *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
61 r3_array->FillRandom(start, end, seed);
62 auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout(
63 LayoutUtil::MakeLayout(minor_to_major));
64 std::unique_ptr<GlobalData> r3_global_data =
65 client_->TransferToServer(r3_data).value();
66 return r3_global_data;
67 }
68
MakeR2Data(absl::Span<const int64_t> bounds,absl::Span<const int64_t> minor_to_major,Shape * r2_shape,Array2D<float> * r2_array,float start,float end,int seed)69 std::unique_ptr<GlobalData> MakeR2Data(
70 absl::Span<const int64_t> bounds,
71 absl::Span<const int64_t> minor_to_major, Shape* r2_shape,
72 Array2D<float>* r2_array, float start, float end, int seed) {
73 *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
74 r2_array->FillRandom(start, end, seed);
75 auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout(
76 LayoutUtil::MakeLayout(minor_to_major));
77 std::unique_ptr<GlobalData> r2_global_data =
78 client_->TransferToServer(r2_data).value();
79 return r2_global_data;
80 }
81
ApplyOpToFloats(HloOpcode op,float lhs,float rhs)82 float ApplyOpToFloats(HloOpcode op, float lhs, float rhs) {
83 switch (op) {
84 case HloOpcode::kMinimum: {
85 return std::min(lhs, rhs);
86 }
87 case HloOpcode::kMaximum: {
88 return std::max(lhs, rhs);
89 }
90 case HloOpcode::kMultiply: {
91 return lhs * rhs;
92 }
93 case HloOpcode::kAdd: {
94 return lhs + rhs;
95 }
96 default: {
97 // Default to Add
98 LOG(FATAL);
99 }
100 }
101 }
102 };
103
104 using ::testing::HasSubstr;
105
// Broadcasting a scalar with an empty list of new dimensions must give back
// the same scalar.
XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) {
  XlaBuilder builder(TestName());

  const XlaOp scalar = ConstantR0<float>(&builder, 1.5);
  Broadcast(scalar, {});

  ComputeAndCompareR0<float>(&builder, 1.5, {}, ErrorSpec(0.0001));
}
111
// Broadcasts a constant scalar into a 2x3 array.
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) {
  XlaBuilder builder(TestName());

  const XlaOp scalar = ConstantR0<float>(&builder, 2.25);
  Broadcast(scalar, {2, 3});

  // Every cell of the result holds the broadcast value.
  const Array2D<float> want(2, 3, 2.25);
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
118
// Same as ScalarTo2D_2x3, but the scalar comes from a parameter instead of a
// constant.
XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) {
  XlaBuilder builder(TestName());

  XlaOp scalar_param;
  std::unique_ptr<GlobalData> scalar_data = CreateR0Parameter<float>(
      2.25f, /*parameter_number=*/0, /*name=*/"src", /*builder=*/&builder,
      /*data_handle=*/&scalar_param);
  Broadcast(scalar_param, {2, 3});

  const Array2D<float> want(2, 3, 2.25);
  ComputeAndCompareR2<float>(&builder, want, {scalar_data.get()},
                             ErrorSpec(0.0001));
}
131
// Broadcasts a scalar into a 2x0 (zero-element) array.
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) {
  XlaBuilder builder(TestName());

  const XlaOp scalar = ConstantR0<float>(&builder, 2.25);
  Broadcast(scalar, {2, 0});

  const Array2D<float> want(2, 0);
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
138
// Broadcasts a scalar into a 0x2 (zero-element) array.
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) {
  XlaBuilder builder(TestName());

  const XlaOp scalar = ConstantR0<float>(&builder, 2.25);
  Broadcast(scalar, {0, 2});

  const Array2D<float> want(0, 2);
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
145
146 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
147 XlaBuilder b(TestName());
148 Broadcast(ConstantR1<float>(&b, {1, 2, 3}), {2});
149
150 Array2D<float> expected(2, 3);
151 expected(0, 0) = 1;
152 expected(0, 1) = 2;
153 expected(0, 2) = 3;
154 expected(1, 0) = 1;
155 expected(1, 1) = 2;
156 expected(1, 2) = 3;
157 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
158 }
159
160 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) {
161 XlaBuilder b(TestName());
162 BroadcastInDim(ConstantR1<float>(&b, {1, 2}), {2, 2}, {1});
163
164 Array2D<float> expected(2, 2);
165 expected(0, 0) = 1;
166 expected(0, 1) = 2;
167 expected(1, 0) = 1;
168 expected(1, 1) = 2;
169
170 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
171 }
172
173 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) {
174 XlaBuilder b(TestName());
175 BroadcastInDim(ConstantR1<float>(&b, {1, 2}), {2, 2}, {0});
176
177 Array2D<float> expected(2, 2);
178 expected(0, 0) = 1;
179 expected(0, 1) = 1;
180 expected(1, 0) = 2;
181 expected(1, 1) = 2;
182
183 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
184 }
185
186 XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) {
187 XlaBuilder b(TestName());
188 BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2},
189 {0, 1});
190
191 Array3D<float> expected(2, 2, 2);
192 expected(0, 0, 0) = 1.0;
193 expected(1, 0, 0) = 2.0;
194 expected(0, 0, 1) = 1.0;
195 expected(1, 0, 1) = 2.0;
196 expected(0, 1, 0) = 5.0;
197 expected(1, 1, 0) = 6.0;
198 expected(1, 1, 1) = 6.0;
199 expected(0, 1, 1) = 5.0;
200
201 ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001));
202 }
203
204 XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) {
205 XlaBuilder b(TestName());
206 BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2},
207 {0, 2});
208
209 Array3D<float> expected(2, 2, 2);
210 expected(0, 0, 0) = 1.0;
211 expected(1, 0, 0) = 2.0;
212 expected(0, 0, 1) = 5.0;
213 expected(1, 0, 1) = 6.0;
214 expected(0, 1, 0) = 1.0;
215 expected(1, 1, 0) = 2.0;
216 expected(1, 1, 1) = 6.0;
217 expected(0, 1, 1) = 5.0;
218
219 ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001));
220 }
221
222 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) {
223 XlaBuilder b(TestName());
224 BroadcastInDim(ConstantR1<float>(&b, {1, 2}), {3, 2}, {1});
225
226 Array2D<float> expected(3, 2);
227 expected(0, 0) = 1;
228 expected(0, 1) = 2;
229 expected(1, 0) = 1;
230 expected(1, 1) = 2;
231 expected(2, 0) = 1;
232 expected(2, 1) = 2;
233
234 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
235 }
236
237 // Tests implicit broadcasting of PREDs.
XLA_TEST_F(BroadcastSimpleTest,BooleanAnd2DTo3D_Pred)238 XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
239 XlaBuilder b(TestName());
240
241 Array2D<bool> x_vals(2, 1);
242 x_vals(0, 0) = true;
243 x_vals(1, 0) = false;
244 Array3D<bool> y_vals(2, 2, 1);
245 y_vals(0, 0, 0) = false;
246 y_vals(0, 1, 0) = false;
247 y_vals(1, 0, 0) = true;
248 y_vals(1, 1, 0) = true;
249
250 XlaOp x, y;
251 auto x_data = CreateR2Parameter<bool>(x_vals, 0, "x", &b, &x);
252 auto y_data = CreateR3Parameter<bool>(y_vals, 1, "y", &b, &y);
253 And(x, y, /*broadcast_dimensions=*/{1, 2});
254
255 Array3D<bool> expected(2, 2, 1);
256 expected(0, 0, 0) = false;
257 expected(0, 1, 0) = false;
258 expected(1, 0, 0) = true;
259 expected(1, 1, 0) = false;
260
261 ComputeAndCompareR3<bool>(&b, expected, {x_data.get(), y_data.get()});
262 }
263
// Broadcasting an empty vector with a new major dimension of size 2 gives a
// 2x0 array.
XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) {
  XlaBuilder builder(TestName());

  const XlaOp empty_vector = ConstantR1<float>(&builder, {});
  Broadcast(empty_vector, {2});

  const Array2D<float> want(2, 0);
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
271
272 XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) {
273 XlaBuilder b(TestName());
274 Broadcast(ConstantR1<float>(&b, {1, 2, 3}), {0});
275
276 Array2D<float> expected(0, 3);
277 ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
278 }
279
XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
  // Exercises in-dimension broadcasting and degenerate-dimension broadcasting
  // within a single binary op.
  //
  // The [1, 2] lhs is mapped onto dimensions {1, 2} of the rank-3 rhs, and is
  // then added to the [2, 3, 1] rhs; the remaining size-one dimensions on
  // either side are expanded by degenerate-dimension broadcasting.
  XlaBuilder builder(TestName());

  const XlaOp lhs = ConstantR2<float>(&builder, {{1.0, 5.0}});
  const XlaOp rhs = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>(
                    {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}}));
  Add(lhs, rhs, /*broadcast_dimensions=*/{1, 2});

  // want(i, j, k) == lhs(0, k) + rhs(i, j, 0).
  const Literal want =
      LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
                                    {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
301
302 struct R3ImplicitBroadcastSpec {
303 std::array<int64_t, 3> output_bounds;
304 std::array<int64_t, 3> minor2major_layout;
305 std::array<int64_t, 3> input_bounds;
306 HloOpcode op;
307 } kR3ImplicitBroadcastTestCases[] = {
308 {{{1, 1, 1}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd},
309 {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 5}}, HloOpcode::kMaximum},
310 {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 1}}, HloOpcode::kMinimum},
311 {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 1}}, HloOpcode::kMultiply},
312 {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd},
313 {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 5}}, HloOpcode::kAdd},
314 {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 4, 1}}, HloOpcode::kAdd},
315 {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 5}}, HloOpcode::kAdd},
316 {{{3, 199, 5}}, {{2, 1, 0}}, {{1, 199, 1}}, HloOpcode::kMinimum},
317 {{{3, 4, 199}}, {{2, 1, 0}}, {{1, 1, 199}}, HloOpcode::kAdd},
318 };
319
320 class BroadcastR3ImplicitTest
321 : public BroadcastSimpleTest,
322 public ::testing::WithParamInterface<R3ImplicitBroadcastSpec> {};
323
// Applies spec.op to an operand with bounds `input_bounds` (parameter 0) and
// a full operand with bounds `output_bounds` (parameter 1), relying on
// implicit broadcasting, and compares the result with a host-side reference.
XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
  const R3ImplicitBroadcastSpec& spec = GetParam();
  XlaBuilder builder(TestName());

  Shape r3_shape, r3_implicit_shape;
  Array3D<float> r3_array(spec.output_bounds[0], spec.output_bounds[1],
                          spec.output_bounds[2]);
  Array3D<float> r3_implicit_array(spec.input_bounds[0], spec.input_bounds[1],
                                   spec.input_bounds[2]);

  // Both operands are filled randomly and uploaded with the same layout.
  std::unique_ptr<GlobalData> r3_global_data =
      MakeR3Data(spec.output_bounds, spec.minor2major_layout, &r3_shape,
                 &r3_array, 1.0, 2.5, 56789);
  std::unique_ptr<GlobalData> r3_implicit_global_data =
      MakeR3Data(spec.input_bounds, spec.minor2major_layout, &r3_implicit_shape,
                 &r3_implicit_array, 1.0, 0.2, 56789);

  // Give the two parameters distinct names, as the rank-2 variant of this
  // test does.
  auto r3_implicit_parameter =
      Parameter(&builder, 0, r3_implicit_shape, "input0");
  auto r3_parameter = Parameter(&builder, 1, r3_shape, "input1");
  BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder);

  Array3D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1],
                                spec.output_bounds[2]);

  // Keep the bounds as int64_t so they match the loop indices and are not
  // narrowed.
  const int64_t n1 = expected_array.n1();
  const int64_t n2 = expected_array.n2();
  const int64_t n3 = expected_array.n3();
  for (int64_t i = 0; i < n1; ++i) {
    for (int64_t j = 0; j < n2; ++j) {
      for (int64_t k = 0; k < n3; ++k) {
        // Taking each index modulo the implicit operand's bound maps a
        // size-1 dimension to index 0 and leaves a full dimension unchanged.
        const float r3_implicit =
            r3_implicit_array(i % spec.input_bounds[0],
                              j % spec.input_bounds[1],
                              k % spec.input_bounds[2]);
        const float r3 = r3_array(i, j, k);
        expected_array(i, j, k) = ApplyOpToFloats(spec.op, r3_implicit, r3);
      }
    }
  }

  auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
  ComputeAndCompareLiteral(
      &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()},
      ErrorSpec(1e-7, 1e-7));
}
371
372 INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances,
373 BroadcastR3ImplicitTest,
374 ::testing::ValuesIn(kR3ImplicitBroadcastTestCases));
375
376 // r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1:
XLA_TEST_F(BroadcastSimpleTest,Add3DTo3DDegenerate_1_2)377 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
378 XlaBuilder b(TestName());
379 XlaOp r1h;
380 XlaOp r3h;
381
382 Array3D<float> r1d = {{{1}}, {{2}}};
383 Array3D<float> r3d = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};
384 auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h);
385 auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h);
386
387 Add(r3h, r1h);
388
389 auto expected =
390 LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
391
392 ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()},
393 ErrorSpec(0.0001));
394 }
395
// Adds a [1, 1, 2] operand (dimensions 0 and 1 degenerate) to a [2, 2, 2]
// operand.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
  XlaBuilder builder(TestName());

  const XlaOp degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR3<float>({{{1, 2}}}));
  const XlaOp full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  const Literal want =
      LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
408
// Adds a [1, 2, 1] operand (dimensions 0 and 2 degenerate) to a [2, 2, 2]
// operand.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
  XlaBuilder builder(TestName());

  const XlaOp degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR3<float>({{{1}, {2}}}));
  const XlaOp full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  const Literal want =
      LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
421
// Adds a [1, 2, 2] operand (dimension 0 degenerate) to a [2, 2, 2] operand.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
  XlaBuilder builder(TestName());

  const XlaOp degenerate = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
  const XlaOp full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  const Literal want =
      LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
435
// Adds a [2, 1, 2] operand (dimension 1 degenerate) to a [2, 2, 2] operand.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
  XlaBuilder builder(TestName());

  const XlaOp degenerate = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
  const XlaOp full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  const Literal want =
      LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
449
// Adds a [2, 2, 1] operand (dimension 2 degenerate) to a [2, 2, 2] operand.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
  XlaBuilder builder(TestName());

  const XlaOp degenerate = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
  const XlaOp full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  const Literal want =
      LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
463
// Adds a [1, 1, 1] operand (all dimensions degenerate) to a [2, 2, 2]
// operand.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
  XlaBuilder builder(TestName());

  const XlaOp degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR3<float>({{{1}}}));
  const XlaOp full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  const Literal want =
      LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
476
477 struct R2ImplicitBroadcastSpec {
478 std::array<int64_t, 2> output_bounds;
479 std::array<int64_t, 2> minor2major_layout;
480 std::array<int64_t, 2> input_bounds1;
481 std::array<int64_t, 2> input_bounds2;
482 HloOpcode op1;
483 HloOpcode op2;
484 } kR2ImplicitBroadcastTestCases[] = {
485 {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd},
486 {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{1, 3}}, HloOpcode::kAdd, HloOpcode::kAdd},
487 {{{2, 3}},
488 {{1, 0}},
489 {{2, 1}},
490 {{1, 1}},
491 HloOpcode::kAdd,
492 HloOpcode::kMinimum},
493 {{{2, 3}},
494 {{1, 0}},
495 {{1, 3}},
496 {{1, 1}},
497 HloOpcode::kAdd,
498 HloOpcode::kMinimum},
499 {{{2, 3}},
500 {{1, 0}},
501 {{1, 1}},
502 {{1, 1}},
503 HloOpcode::kAdd,
504 HloOpcode::kMinimum},
505 {{{2, 3}}, {{0, 1}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd},
506 {{{150, 150}},
507 {{1, 0}},
508 {{150, 1}},
509 {{150, 1}},
510 HloOpcode::kAdd,
511 HloOpcode::kAdd},
512 {{{150, 150}},
513 {{1, 0}},
514 {{150, 1}},
515 {{1, 150}},
516 HloOpcode::kAdd,
517 HloOpcode::kAdd},
518 {{{150, 150}},
519 {{1, 0}},
520 {{150, 1}},
521 {{1, 1}},
522 HloOpcode::kAdd,
523 HloOpcode::kAdd},
524 {{{50, 150}},
525 {{1, 0}},
526 {{50, 1}},
527 {{50, 1}},
528 HloOpcode::kAdd,
529 HloOpcode::kAdd},
530 {{{50, 150}},
531 {{1, 0}},
532 {{50, 1}},
533 {{1, 150}},
534 HloOpcode::kAdd,
535 HloOpcode::kAdd},
536 {{{50, 150}},
537 {{1, 0}},
538 {{50, 1}},
539 {{1, 1}},
540 HloOpcode::kAdd,
541 HloOpcode::kAdd},
542 {{{150, 50}},
543 {{1, 0}},
544 {{150, 1}},
545 {{150, 1}},
546 HloOpcode::kAdd,
547 HloOpcode::kAdd},
548 {{{150, 50}},
549 {{1, 0}},
550 {{150, 1}},
551 {{1, 50}},
552 HloOpcode::kAdd,
553 HloOpcode::kAdd},
554 {{{150, 50}},
555 {{1, 0}},
556 {{150, 1}},
557 {{1, 1}},
558 HloOpcode::kAdd,
559 HloOpcode::kAdd}};
560
561 class BroadcastR2ImplicitTest
562 : public BroadcastSimpleTest,
563 public ::testing::WithParamInterface<R2ImplicitBroadcastSpec> {};
564
// Tests (r2_implicit_1 op1 r2) op2 r2_implicit_2,
// where r2 is a full rank-2 operand, and r2_implicit_1 and r2_implicit_2 are
// two rank-2 operands with degenerate dimensions:
XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
  const R2ImplicitBroadcastSpec& spec = GetParam();
  XlaBuilder builder(TestName());

  // Host-side copies of the three operands. The two "implicit" operands use
  // the spec's input bounds and rely on implicit broadcasting.
  Array2D<float> full_values(spec.output_bounds[0], spec.output_bounds[1]);
  Array2D<float> implicit_values1(spec.input_bounds1[0],
                                  spec.input_bounds1[1]);
  Array2D<float> implicit_values2(spec.input_bounds2[0],
                                  spec.input_bounds2[1]);

  // Fill and upload each operand; all three share the spec's layout.
  Shape full_shape;
  Shape implicit_shape1;
  Shape implicit_shape2;
  std::unique_ptr<GlobalData> full_data =
      MakeR2Data(spec.output_bounds, spec.minor2major_layout, &full_shape,
                 &full_values, 1.0, 2.5, 56789);
  std::unique_ptr<GlobalData> implicit_data1 =
      MakeR2Data(spec.input_bounds1, spec.minor2major_layout, &implicit_shape1,
                 &implicit_values1, 1.0, 0.2, 56789);
  std::unique_ptr<GlobalData> implicit_data2 =
      MakeR2Data(spec.input_bounds2, spec.minor2major_layout, &implicit_shape2,
                 &implicit_values2, 0.8, 0.4, 56789);

  const XlaOp implicit_param1 =
      Parameter(&builder, 0, implicit_shape1, "input0");
  const XlaOp full_param = Parameter(&builder, 1, full_shape, "input1");
  const XlaOp implicit_param2 =
      Parameter(&builder, 2, implicit_shape2, "input2");

  // (implicit1 op1 full) op2 implicit2
  const XlaOp first_result =
      BuildBinOp(spec.op1, implicit_param1, full_param, &builder);
  BuildBinOp(spec.op2, first_result, implicit_param2, &builder);

  // Host-side reference. Taking each index modulo an implicit operand's bound
  // maps a size-1 dimension to index 0 and leaves a full dimension unchanged.
  Array2D<float> expected_values(spec.output_bounds[0], spec.output_bounds[1]);
  for (int64_t row = 0; row < expected_values.n1(); ++row) {
    for (int64_t col = 0; col < expected_values.n2(); ++col) {
      const float implicit1 = implicit_values1(row % spec.input_bounds1[0],
                                               col % spec.input_bounds1[1]);
      const float full = full_values(row, col);
      const float implicit2 = implicit_values2(row % spec.input_bounds2[0],
                                               col % spec.input_bounds2[1]);
      const float first = ApplyOpToFloats(spec.op1, implicit1, full);
      expected_values(row, col) = ApplyOpToFloats(spec.op2, first, implicit2);
    }
  }

  const Literal want = LiteralUtil::CreateR2FromArray2D(expected_values);
  ComputeAndCompareLiteral(
      &builder, want,
      {implicit_data1.get(), full_data.get(), implicit_data2.get()},
      ErrorSpec(1e-6, 1e-6));
}
620
621 INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
622 BroadcastR2ImplicitTest,
623 ::testing::ValuesIn(kR2ImplicitBroadcastTestCases));
624
// Adds a [1, 2] operand (dimension 0 degenerate) to a [2, 2] operand.
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
  XlaBuilder builder(TestName());

  const XlaOp degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR2<float>({{1, 2}}));
  const XlaOp full = ConstantLiteral(
      &builder, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
  Add(full, degenerate);

  const Literal want = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
635
// Adds a [2, 1] operand (dimension 1 degenerate) to a [2, 2] operand.
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
  XlaBuilder builder(TestName());

  const XlaOp degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR2<float>({{1}, {2}}));
  const XlaOp full = ConstantLiteral(
      &builder, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
  Add(full, degenerate);

  const Literal want = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
646
// Adds a 2-element vector to a [2, 2, 2] array, mapping the vector onto
// dimension 0.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
  XlaBuilder builder(TestName());

  const XlaOp vector = ConstantR1<float>(&builder, {10, 20});
  const XlaOp cube = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(cube, vector, {0});

  const Literal want = LiteralUtil::CreateR3<float>(
      {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
659
// Adds a 2-element vector to a [2, 2, 2] array, mapping the vector onto
// dimension 1.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
  XlaBuilder builder(TestName());

  const XlaOp vector = ConstantR1<float>(&builder, {10, 20});
  const XlaOp cube = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(vector, cube, {1});

  const Literal want = LiteralUtil::CreateR3<float>(
      {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
672
// Adds a 2-element vector to a [2, 2, 2] array, mapping the vector onto
// dimension 2.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
  XlaBuilder builder(TestName());

  const XlaOp vector = ConstantR1<float>(&builder, {10, 20});
  const XlaOp cube = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(vector, cube, {2});

  const Literal want = LiteralUtil::CreateR3<float>(
      {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
685
// Repeatedly adds one vector per dimension to a [2, 2, 2] array, then scales
// the sum by -2.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
  XlaBuilder builder(TestName());

  const XlaOp dim0_vector = ConstantR1<float>(&builder, {1000, 2000});
  const XlaOp dim1_vector = ConstantR1<float>(&builder, {100, 200});
  const XlaOp dim2_vector = ConstantR1<float>(&builder, {10, 20});
  XlaOp accumulator = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));

  // Three rounds; each round adds all three vectors along their dimension.
  for (int round = 0; round < 3; ++round) {
    accumulator = Add(dim0_vector, accumulator, {0});
    accumulator = Add(accumulator, dim1_vector, {1});
    accumulator = Add(dim2_vector, accumulator, {2});
  }
  accumulator = Mul(accumulator, ConstantR0<float>(&builder, -2));

  // -2 * (initial + 3 * vector_sum) == -6 * vector_sum - 2 * initial.
  const Literal want = LiteralUtil::CreateR3<float>(
      {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
       {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
706
// Like Add1DTo3DInDimAll, but the rank-3 starting value is a scalar 3
// broadcast to [2, 2, 2], and the final scale factor is -1.
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
  XlaBuilder builder(TestName());

  const XlaOp dim0_vector = ConstantR1<float>(&builder, {1000, 2000});
  const XlaOp dim1_vector = ConstantR1<float>(&builder, {100, 200});
  const XlaOp dim2_vector = ConstantR1<float>(&builder, {10, 20});
  const XlaOp scalar = ConstantR0<float>(&builder, 3);
  XlaOp accumulator = Broadcast(scalar, {2, 2, 2});

  // Three rounds; each round adds all three vectors along their dimension.
  for (int round = 0; round < 3; ++round) {
    accumulator = Add(dim0_vector, accumulator, {0});
    accumulator = Add(accumulator, dim1_vector, {1});
    accumulator = Add(dim2_vector, accumulator, {2});
  }
  accumulator = Mul(accumulator, ConstantR0<float>(&builder, -1));

  // -1 * (3 + 3 * vector_sum) == -3 * vector_sum - 3.
  const Literal want = LiteralUtil::CreateR3<float>(
      {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
       {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
727
XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
  // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2])
  // results in a shape incompatible with the rhs [2, 3, 1], so executing the
  // computation must fail.
  XlaBuilder b(TestName());

  Add(ConstantR2<float>(&b, {{1.0, 5.0}, {1.0, 5.0}}),
      ConstantLiteral(&b, LiteralUtil::CreateR3<float>(
                              {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
      /*broadcast_dimensions=*/{1, 2});

  // The error is expected to identify the mismatching dimension.
  auto result_status = Execute(&b, {});
  EXPECT_FALSE(result_status.ok());
  EXPECT_THAT(result_status.status().error_message(),
              HasSubstr("dimension 0 mismatch"));
}
743
XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
  // Adding a [1, 2] operand to a [2, 3] operand must be rejected: dimension 1
  // differs (2 vs. 3) and neither side has size one there.
  XlaBuilder builder(TestName());

  const XlaOp lhs = ConstantR2<float>(&builder, {{1.0, 2.0}});
  const XlaOp rhs =
      ConstantR2<float>(&builder, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  Add(lhs, rhs);

  auto execution = Execute(&builder, {});
  EXPECT_FALSE(execution.ok());
  EXPECT_THAT(execution.status().error_message(),
              HasSubstr("op add with incompatible shapes"));
}
756
XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
  // Adding a [1, 2] operand to a [2, 3] operand must be rejected.
  // NOTE(review): the operands and the expected error are the same as in
  // InvalidInDimensionBroadcasting -- confirm whether a different shape pair
  // was intended here.
  XlaBuilder builder(TestName());

  const XlaOp lhs = ConstantR2<float>(&builder, {{1.0, 2.0}});
  const XlaOp rhs =
      ConstantR2<float>(&builder, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  Add(lhs, rhs);

  auto execution = Execute(&builder, {});
  EXPECT_FALSE(execution.ok());
  EXPECT_THAT(execution.status().error_message(),
              HasSubstr("op add with incompatible shapes"));
}
769
770 } // namespace
771 } // namespace xla
772