xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/broadcast_simple_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <numeric>
18 #include <vector>
19 
20 #include "tensorflow/compiler/xla/array2d.h"
21 #include "tensorflow/compiler/xla/array4d.h"
22 #include "tensorflow/compiler/xla/client/local_client.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/literal_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/test.h"
28 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
29 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
30 #include "tensorflow/compiler/xla/tests/test_macros.h"
31 
32 namespace xla {
33 namespace {
34 
35 class BroadcastSimpleTest : public ClientLibraryTestBase {
36  public:
BuildBinOp(HloOpcode op,const XlaOp lhs,const XlaOp rhs,XlaBuilder * builder)37   XlaOp BuildBinOp(HloOpcode op, const XlaOp lhs, const XlaOp rhs,
38                    XlaBuilder* builder) {
39     switch (op) {
40       case HloOpcode::kMinimum: {
41         return Min(lhs, rhs);
42       }
43       case HloOpcode::kMaximum: {
44         return Max(lhs, rhs);
45       }
46       case HloOpcode::kMultiply: {
47         return Mul(lhs, rhs);
48       }
49       default: {
50         // Default to Add
51         return Add(lhs, rhs);
52       }
53     }
54   }
55 
MakeR3Data(absl::Span<const int64_t> bounds,absl::Span<const int64_t> minor_to_major,Shape * r3_shape,Array3D<float> * r3_array,float start,float end,int seed)56   std::unique_ptr<GlobalData> MakeR3Data(
57       absl::Span<const int64_t> bounds,
58       absl::Span<const int64_t> minor_to_major, Shape* r3_shape,
59       Array3D<float>* r3_array, float start, float end, int seed) {
60     *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
61     r3_array->FillRandom(start, end, seed);
62     auto r3_data = LiteralUtil::CreateR3FromArray3D(*r3_array).Relayout(
63         LayoutUtil::MakeLayout(minor_to_major));
64     std::unique_ptr<GlobalData> r3_global_data =
65         client_->TransferToServer(r3_data).value();
66     return r3_global_data;
67   }
68 
MakeR2Data(absl::Span<const int64_t> bounds,absl::Span<const int64_t> minor_to_major,Shape * r2_shape,Array2D<float> * r2_array,float start,float end,int seed)69   std::unique_ptr<GlobalData> MakeR2Data(
70       absl::Span<const int64_t> bounds,
71       absl::Span<const int64_t> minor_to_major, Shape* r2_shape,
72       Array2D<float>* r2_array, float start, float end, int seed) {
73     *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major);
74     r2_array->FillRandom(start, end, seed);
75     auto r2_data = LiteralUtil::CreateR2FromArray2D(*r2_array).Relayout(
76         LayoutUtil::MakeLayout(minor_to_major));
77     std::unique_ptr<GlobalData> r2_global_data =
78         client_->TransferToServer(r2_data).value();
79     return r2_global_data;
80   }
81 
ApplyOpToFloats(HloOpcode op,float lhs,float rhs)82   float ApplyOpToFloats(HloOpcode op, float lhs, float rhs) {
83     switch (op) {
84       case HloOpcode::kMinimum: {
85         return std::min(lhs, rhs);
86       }
87       case HloOpcode::kMaximum: {
88         return std::max(lhs, rhs);
89       }
90       case HloOpcode::kMultiply: {
91         return lhs * rhs;
92       }
93       case HloOpcode::kAdd: {
94         return lhs + rhs;
95       }
96       default: {
97         // Default to Add
98         LOG(FATAL);
99       }
100     }
101   }
102 };
103 
104 using ::testing::HasSubstr;
105 
// Broadcasting a scalar with an empty list of added dimensions is expected to
// yield the same scalar value.
XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) {
  constexpr float kValue = 1.5f;

  XlaBuilder builder(TestName());
  auto scalar = ConstantR0<float>(&builder, kValue);
  Broadcast(scalar, {});

  ComputeAndCompareR0<float>(&builder, kValue, {}, ErrorSpec(0.0001));
}
111 
// Broadcasts a constant scalar to a 2x3 array; every element must equal the
// scalar.
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) {
  constexpr float kValue = 2.25f;

  XlaBuilder builder(TestName());
  auto scalar = ConstantR0<float>(&builder, kValue);
  Broadcast(scalar, {2, 3});

  Array2D<float> want(2, 3, kValue);
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
118 
// Same as ScalarTo2D_2x3, but the scalar is supplied as a computation
// parameter instead of a constant.
XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) {
  constexpr float kValue = 2.25f;

  XlaBuilder builder(TestName());
  XlaOp scalar_param;
  std::unique_ptr<GlobalData> scalar_data = CreateR0Parameter<float>(
      kValue, /*parameter_number=*/0, /*name=*/"src", /*builder=*/&builder,
      /*data_handle=*/&scalar_param);
  Broadcast(scalar_param, {2, 3});

  Array2D<float> want(2, 3, kValue);
  ComputeAndCompareR2<float>(&builder, want, {scalar_data.get()},
                             ErrorSpec(0.0001));
}
131 
// Broadcasts a scalar to a zero-element 2x0 array (minor dimension empty).
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) {
  XlaBuilder builder(TestName());
  auto scalar = ConstantR0<float>(&builder, 2.25);
  Broadcast(scalar, {2, 0});

  Array2D<float> want(2, 0);
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
138 
// Broadcasts a scalar to a zero-element 0x2 array (major dimension empty).
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) {
  XlaBuilder builder(TestName());
  auto scalar = ConstantR0<float>(&builder, 2.25);
  Broadcast(scalar, {0, 2});

  Array2D<float> want(0, 2);
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
145 
146 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
147   XlaBuilder b(TestName());
148   Broadcast(ConstantR1<float>(&b, {1, 2, 3}), {2});
149 
150   Array2D<float> expected(2, 3);
151   expected(0, 0) = 1;
152   expected(0, 1) = 2;
153   expected(0, 2) = 3;
154   expected(1, 0) = 1;
155   expected(1, 1) = 2;
156   expected(1, 2) = 3;
157   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
158 }
159 
160 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsUsual) {
161   XlaBuilder b(TestName());
162   BroadcastInDim(ConstantR1<float>(&b, {1, 2}), {2, 2}, {1});
163 
164   Array2D<float> expected(2, 2);
165   expected(0, 0) = 1;
166   expected(0, 1) = 2;
167   expected(1, 0) = 1;
168   expected(1, 1) = 2;
169 
170   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
171 }
172 
173 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsTranspose) {
174   XlaBuilder b(TestName());
175   BroadcastInDim(ConstantR1<float>(&b, {1, 2}), {2, 2}, {0});
176 
177   Array2D<float> expected(2, 2);
178   expected(0, 0) = 1;
179   expected(0, 1) = 1;
180   expected(1, 0) = 2;
181   expected(1, 1) = 2;
182 
183   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
184 }
185 
186 XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDims) {
187   XlaBuilder b(TestName());
188   BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2},
189                  {0, 1});
190 
191   Array3D<float> expected(2, 2, 2);
192   expected(0, 0, 0) = 1.0;
193   expected(1, 0, 0) = 2.0;
194   expected(0, 0, 1) = 1.0;
195   expected(1, 0, 1) = 2.0;
196   expected(0, 1, 0) = 5.0;
197   expected(1, 1, 0) = 6.0;
198   expected(1, 1, 1) = 6.0;
199   expected(0, 1, 1) = 5.0;
200 
201   ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001));
202 }
203 
204 XLA_TEST_F(BroadcastSimpleTest, 2DTo3D_WithDimsNotPossibleWithBroadCast) {
205   XlaBuilder b(TestName());
206   BroadcastInDim(ConstantR2<float>(&b, {{1.0, 5.0}, {2.0, 6.0}}), {2, 2, 2},
207                  {0, 2});
208 
209   Array3D<float> expected(2, 2, 2);
210   expected(0, 0, 0) = 1.0;
211   expected(1, 0, 0) = 2.0;
212   expected(0, 0, 1) = 5.0;
213   expected(1, 0, 1) = 6.0;
214   expected(0, 1, 0) = 1.0;
215   expected(1, 1, 0) = 2.0;
216   expected(1, 1, 1) = 6.0;
217   expected(0, 1, 1) = 5.0;
218 
219   ComputeAndCompareR3<float>(&b, expected, {}, ErrorSpec(0.0001));
220 }
221 
222 XLA_TEST_F(BroadcastSimpleTest, 1DTo2D_WithDimsNotPossibleWithBroadCast) {
223   XlaBuilder b(TestName());
224   BroadcastInDim(ConstantR1<float>(&b, {1, 2}), {3, 2}, {1});
225 
226   Array2D<float> expected(3, 2);
227   expected(0, 0) = 1;
228   expected(0, 1) = 2;
229   expected(1, 0) = 1;
230   expected(1, 1) = 2;
231   expected(2, 0) = 1;
232   expected(2, 1) = 2;
233 
234   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
235 }
236 
237 // Tests implicit broadcasting of PREDs.
XLA_TEST_F(BroadcastSimpleTest,BooleanAnd2DTo3D_Pred)238 XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
239   XlaBuilder b(TestName());
240 
241   Array2D<bool> x_vals(2, 1);
242   x_vals(0, 0) = true;
243   x_vals(1, 0) = false;
244   Array3D<bool> y_vals(2, 2, 1);
245   y_vals(0, 0, 0) = false;
246   y_vals(0, 1, 0) = false;
247   y_vals(1, 0, 0) = true;
248   y_vals(1, 1, 0) = true;
249 
250   XlaOp x, y;
251   auto x_data = CreateR2Parameter<bool>(x_vals, 0, "x", &b, &x);
252   auto y_data = CreateR3Parameter<bool>(y_vals, 1, "y", &b, &y);
253   And(x, y, /*broadcast_dimensions=*/{1, 2});
254 
255   Array3D<bool> expected(2, 2, 1);
256   expected(0, 0, 0) = false;
257   expected(0, 1, 0) = false;
258   expected(1, 0, 0) = true;
259   expected(1, 1, 0) = false;
260 
261   ComputeAndCompareR3<bool>(&b, expected, {x_data.get(), y_data.get()});
262 }
263 
// Broadcasting an empty vector with a new major dimension of size 2 gives a
// 2x0 array.
XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) {
  XlaBuilder builder(TestName());
  auto empty_vec = ConstantR1<float>(&builder, {});
  Broadcast(empty_vec, {2});

  Array2D<float> want(2, 0);
  ComputeAndCompareR2<float>(&builder, want, {}, ErrorSpec(0.0001));
}
271 
272 XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) {
273   XlaBuilder b(TestName());
274   Broadcast(ConstantR1<float>(&b, {1, 2, 3}), {0});
275 
276   Array2D<float> expected(0, 3);
277   ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
278 }
279 
// Checks that in-dimension broadcasting and degenerate-dimension broadcasting
// cooperate within a single binary op.
//
// The [1, 2] lhs is mapped onto dimensions {1, 2} of the rank-3 rhs (shape
// [2, 3, 1]), which raises it to [2, 1, 2]. The remaining size-1 dimensions on
// both sides are then expanded by degenerate broadcasting, giving a [2, 3, 2]
// result where result(i, j, k) == lhs(0, k) + rhs(i, j, 0).
XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
  XlaBuilder builder(TestName());

  auto lhs = ConstantR2<float>(&builder, {{1.0, 5.0}});
  auto rhs = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>(
                    {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}}));
  Add(lhs, rhs, /*broadcast_dimensions=*/{1, 2});

  auto want =
      LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
                                    {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
301 
302 struct R3ImplicitBroadcastSpec {
303   std::array<int64_t, 3> output_bounds;
304   std::array<int64_t, 3> minor2major_layout;
305   std::array<int64_t, 3> input_bounds;
306   HloOpcode op;
307 } kR3ImplicitBroadcastTestCases[] = {
308     {{{1, 1, 1}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd},
309     {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 5}}, HloOpcode::kMaximum},
310     {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 1}}, HloOpcode::kMinimum},
311     {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 1}}, HloOpcode::kMultiply},
312     {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 1, 1}}, HloOpcode::kAdd},
313     {{{3, 4, 5}}, {{2, 1, 0}}, {{1, 4, 5}}, HloOpcode::kAdd},
314     {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 4, 1}}, HloOpcode::kAdd},
315     {{{3, 4, 5}}, {{2, 1, 0}}, {{3, 1, 5}}, HloOpcode::kAdd},
316     {{{3, 199, 5}}, {{2, 1, 0}}, {{1, 199, 1}}, HloOpcode::kMinimum},
317     {{{3, 4, 199}}, {{2, 1, 0}}, {{1, 1, 199}}, HloOpcode::kAdd},
318 };
319 
320 class BroadcastR3ImplicitTest
321     : public BroadcastSimpleTest,
322       public ::testing::WithParamInterface<R3ImplicitBroadcastSpec> {};
323 
XLA_TEST_P(BroadcastR3ImplicitTest,Doit)324 XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
325   const R3ImplicitBroadcastSpec& spec = GetParam();
326   XlaBuilder builder(TestName());
327 
328   Shape r3_shape, r3_implicit_shape;
329   Array3D<float> r3_array(spec.output_bounds[0], spec.output_bounds[1],
330                           spec.output_bounds[2]);
331   Array3D<float> r3_implicit_array(spec.input_bounds[0], spec.input_bounds[1],
332                                    spec.input_bounds[2]);
333 
334   std::unique_ptr<GlobalData> r3_global_data =
335       MakeR3Data(spec.output_bounds, spec.minor2major_layout, &r3_shape,
336                  &r3_array, 1.0, 2.5, 56789);
337   std::unique_ptr<GlobalData> r3_implicit_global_data =
338       MakeR3Data(spec.input_bounds, spec.minor2major_layout, &r3_implicit_shape,
339                  &r3_implicit_array, 1.0, 0.2, 56789);
340 
341   auto r3_implicit_parameter =
342       Parameter(&builder, 0, r3_implicit_shape, "input");
343   auto r3_parameter = Parameter(&builder, 1, r3_shape, "input");
344   BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder);
345 
346   Array3D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1],
347                                 spec.output_bounds[2]);
348   auto Each = ([&](absl::Span<const int64_t> indices, float* value) {
349     float r3_implicit = r3_implicit_array(indices[0] % spec.input_bounds[0],
350                                           indices[1] % spec.input_bounds[1],
351                                           indices[2] % spec.input_bounds[2]);
352     float r3 = r3_array(indices[0], indices[1], indices[2]);
353     *value = ApplyOpToFloats(spec.op, r3_implicit, r3);
354   });
355 
356   int n1 = expected_array.n1();
357   int n2 = expected_array.n2();
358   int n3 = expected_array.n3();
359   for (int64_t i = 0; i < n1; i++) {
360     for (int64_t j = 0; j < n2; j++) {
361       for (int64_t k = 0; k < n3; k++) {
362         Each({i, j, k}, &expected_array(i, j, k));
363       }
364     }
365   }
366   auto expected = LiteralUtil::CreateR3FromArray3D(expected_array);
367   ComputeAndCompareLiteral(
368       &builder, expected, {r3_implicit_global_data.get(), r3_global_data.get()},
369       ErrorSpec(1e-7, 1e-7));
370 }
371 
372 INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances,
373                         BroadcastR3ImplicitTest,
374                         ::testing::ValuesIn(kR3ImplicitBroadcastTestCases));
375 
376 // r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1:
XLA_TEST_F(BroadcastSimpleTest,Add3DTo3DDegenerate_1_2)377 XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
378   XlaBuilder b(TestName());
379   XlaOp r1h;
380   XlaOp r3h;
381 
382   Array3D<float> r1d = {{{1}}, {{2}}};
383   Array3D<float> r3d = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};
384   auto r1 = CreateR3Parameter(r1d, 1, "r1", &b, &r1h);
385   auto r3 = CreateR3Parameter(r3d, 0, "r3", &b, &r3h);
386 
387   Add(r3h, r1h);
388 
389   auto expected =
390       LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}});
391 
392   ComputeAndCompareLiteral(&b, expected, {r3.get(), r1.get()},
393                            ErrorSpec(0.0001));
394 }
395 
// Adds a [1, 1, 2] operand (dims 0 and 1 degenerate) to a [2, 2, 2] operand.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
  XlaBuilder builder(TestName());

  auto degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR3<float>({{{1, 2}}}));
  auto full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
408 
// Adds a [1, 2, 1] operand (dims 0 and 2 degenerate) to a [2, 2, 2] operand.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
  XlaBuilder builder(TestName());

  auto degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR3<float>({{{1}, {2}}}));
  auto full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
421 
// Adds a [1, 2, 2] operand (dim 0 degenerate) to a [2, 2, 2] operand.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
  XlaBuilder builder(TestName());

  auto degenerate = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}}));
  auto full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
435 
// Adds a [2, 1, 2] operand (dim 1 degenerate) to a [2, 2, 2] operand.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
  XlaBuilder builder(TestName());

  auto degenerate = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
  auto full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
449 
// Adds a [2, 2, 1] operand (dim 2 degenerate) to a [2, 2, 2] operand.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
  XlaBuilder builder(TestName());

  auto degenerate = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
  auto full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
463 
// Adds a [1, 1, 1] operand (all dims degenerate) to a [2, 2, 2] operand.
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
  XlaBuilder builder(TestName());

  auto degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR3<float>({{{1}}}));
  auto full = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(full, degenerate);

  auto want =
      LiteralUtil::CreateR3<float>({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
476 
477 struct R2ImplicitBroadcastSpec {
478   std::array<int64_t, 2> output_bounds;
479   std::array<int64_t, 2> minor2major_layout;
480   std::array<int64_t, 2> input_bounds1;
481   std::array<int64_t, 2> input_bounds2;
482   HloOpcode op1;
483   HloOpcode op2;
484 } kR2ImplicitBroadcastTestCases[] = {
485     {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd},
486     {{{2, 3}}, {{1, 0}}, {{2, 1}}, {{1, 3}}, HloOpcode::kAdd, HloOpcode::kAdd},
487     {{{2, 3}},
488      {{1, 0}},
489      {{2, 1}},
490      {{1, 1}},
491      HloOpcode::kAdd,
492      HloOpcode::kMinimum},
493     {{{2, 3}},
494      {{1, 0}},
495      {{1, 3}},
496      {{1, 1}},
497      HloOpcode::kAdd,
498      HloOpcode::kMinimum},
499     {{{2, 3}},
500      {{1, 0}},
501      {{1, 1}},
502      {{1, 1}},
503      HloOpcode::kAdd,
504      HloOpcode::kMinimum},
505     {{{2, 3}}, {{0, 1}}, {{2, 1}}, {{2, 1}}, HloOpcode::kAdd, HloOpcode::kAdd},
506     {{{150, 150}},
507      {{1, 0}},
508      {{150, 1}},
509      {{150, 1}},
510      HloOpcode::kAdd,
511      HloOpcode::kAdd},
512     {{{150, 150}},
513      {{1, 0}},
514      {{150, 1}},
515      {{1, 150}},
516      HloOpcode::kAdd,
517      HloOpcode::kAdd},
518     {{{150, 150}},
519      {{1, 0}},
520      {{150, 1}},
521      {{1, 1}},
522      HloOpcode::kAdd,
523      HloOpcode::kAdd},
524     {{{50, 150}},
525      {{1, 0}},
526      {{50, 1}},
527      {{50, 1}},
528      HloOpcode::kAdd,
529      HloOpcode::kAdd},
530     {{{50, 150}},
531      {{1, 0}},
532      {{50, 1}},
533      {{1, 150}},
534      HloOpcode::kAdd,
535      HloOpcode::kAdd},
536     {{{50, 150}},
537      {{1, 0}},
538      {{50, 1}},
539      {{1, 1}},
540      HloOpcode::kAdd,
541      HloOpcode::kAdd},
542     {{{150, 50}},
543      {{1, 0}},
544      {{150, 1}},
545      {{150, 1}},
546      HloOpcode::kAdd,
547      HloOpcode::kAdd},
548     {{{150, 50}},
549      {{1, 0}},
550      {{150, 1}},
551      {{1, 50}},
552      HloOpcode::kAdd,
553      HloOpcode::kAdd},
554     {{{150, 50}},
555      {{1, 0}},
556      {{150, 1}},
557      {{1, 1}},
558      HloOpcode::kAdd,
559      HloOpcode::kAdd}};
560 
561 class BroadcastR2ImplicitTest
562     : public BroadcastSimpleTest,
563       public ::testing::WithParamInterface<R2ImplicitBroadcastSpec> {};
564 
565 // Test r2 op1 r2_implicit_1 op2 r2_implicit_2
566 // where R2 is a rank-2 operand, and r2_implicit_2 are two
567 // rank-2 operands with degenerate dimensions:
XLA_TEST_P(BroadcastR2ImplicitTest,Doit)568 XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
569   const R2ImplicitBroadcastSpec& spec = GetParam();
570 
571   XlaBuilder builder(TestName());
572 
573   // Operands with degenerate dimensions require implicit broadcasting:
574   Shape r2_shape, r2_implicit_shape1, r2_implicit_shape2;
575   Array2D<float> r2_array(spec.output_bounds[0], spec.output_bounds[1]);
576   Array2D<float> r2_implicit_array1(spec.input_bounds1[0],
577                                     spec.input_bounds1[1]);
578   Array2D<float> r2_implicit_array2(spec.input_bounds2[0],
579                                     spec.input_bounds2[1]);
580 
581   std::unique_ptr<GlobalData> r2_global_data =
582       MakeR2Data(spec.output_bounds, spec.minor2major_layout, &r2_shape,
583                  &r2_array, 1.0, 2.5, 56789);
584   std::unique_ptr<GlobalData> r2_implicit_global_data1 =
585       MakeR2Data(spec.input_bounds1, spec.minor2major_layout,
586                  &r2_implicit_shape1, &r2_implicit_array1, 1.0, 0.2, 56789);
587   std::unique_ptr<GlobalData> r2_implicit_global_data2 =
588       MakeR2Data(spec.input_bounds2, spec.minor2major_layout,
589                  &r2_implicit_shape2, &r2_implicit_array2, 0.8, 0.4, 56789);
590 
591   auto r2_implicit_parameter1 =
592       Parameter(&builder, 0, r2_implicit_shape1, "input0");
593   auto r2_parameter = Parameter(&builder, 1, r2_shape, "input1");
594   auto r2_implicit_parameter2 =
595       Parameter(&builder, 2, r2_implicit_shape2, "input2");
596 
597   XlaOp op1 =
598       BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder);
599   BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder);
600 
601   Array2D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1]);
602 
603   expected_array.Each([&](int64_t i, int64_t j, float* v) {
604     float v1 = r2_implicit_array1(i % spec.input_bounds1[0],
605                                   j % spec.input_bounds1[1]);
606     float v2 = r2_array(i, j);
607     float v3 = r2_implicit_array2(i % spec.input_bounds2[0],
608                                   j % spec.input_bounds2[1]);
609     float tmp = ApplyOpToFloats(spec.op1, v1, v2);
610     *v = ApplyOpToFloats(spec.op2, tmp, v3);
611   });
612 
613   auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
614   ComputeAndCompareLiteral(
615       &builder, expected,
616       {r2_implicit_global_data1.get(), r2_global_data.get(),
617        r2_implicit_global_data2.get()},
618       ErrorSpec(1e-6, 1e-6));
619 }
620 
621 INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
622                         BroadcastR2ImplicitTest,
623                         ::testing::ValuesIn(kR2ImplicitBroadcastTestCases));
624 
// Adds a [1, 2] operand (dim 0 degenerate) to a [2, 2] operand.
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
  XlaBuilder builder(TestName());

  auto degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR2<float>({{1, 2}}));
  auto full =
      ConstantLiteral(&builder, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
  Add(full, degenerate);

  auto want = LiteralUtil::CreateR2<float>({{2, 4}, {4, 6}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
635 
// Adds a [2, 1] operand (dim 1 degenerate) to a [2, 2] operand.
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
  XlaBuilder builder(TestName());

  auto degenerate =
      ConstantLiteral(&builder, LiteralUtil::CreateR2<float>({{1}, {2}}));
  auto full =
      ConstantLiteral(&builder, LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}}));
  Add(full, degenerate);

  auto want = LiteralUtil::CreateR2<float>({{2, 3}, {5, 6}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
646 
// Adds a length-2 vector along dimension 0 of a [2, 2, 2] operand:
// result(i, j, k) == cube(i, j, k) + vec[i].
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
  XlaBuilder builder(TestName());

  auto vec = ConstantR1<float>(&builder, {10, 20});
  auto cube = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(cube, vec, {0});

  auto want = LiteralUtil::CreateR3<float>(
      {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
659 
// Adds a length-2 vector along dimension 1 of a [2, 2, 2] operand:
// result(i, j, k) == vec[j] + cube(i, j, k).
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
  XlaBuilder builder(TestName());

  auto vec = ConstantR1<float>(&builder, {10, 20});
  auto cube = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(vec, cube, {1});

  auto want = LiteralUtil::CreateR3<float>(
      {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
672 
// Adds a length-2 vector along dimension 2 of a [2, 2, 2] operand:
// result(i, j, k) == vec[k] + cube(i, j, k).
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
  XlaBuilder builder(TestName());

  auto vec = ConstantR1<float>(&builder, {10, 20});
  auto cube = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
  Add(vec, cube, {2});

  auto want = LiteralUtil::CreateR3<float>(
      {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
685 
// Repeatedly adds one vector along each of the three dimensions of a
// [2, 2, 2] operand (three rounds), then scales the sum by -2. Each expected
// element is therefore -2 * original - 6 * (vec0[i] + vec1[j] + vec2[k]).
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
  XlaBuilder builder(TestName());

  auto vec_dim0 = ConstantR1<float>(&builder, {1000, 2000});
  auto vec_dim1 = ConstantR1<float>(&builder, {100, 200});
  auto vec_dim2 = ConstantR1<float>(&builder, {10, 20});
  auto acc = ConstantLiteral(
      &builder,
      LiteralUtil::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));

  // The accumulator alternates between the rhs and lhs position across the
  // three adds of a round.
  for (int round = 0; round != 3; ++round) {
    acc = Add(vec_dim0, acc, {0});
    acc = Add(acc, vec_dim1, {1});
    acc = Add(vec_dim2, acc, {2});
  }
  auto scale = ConstantR0<float>(&builder, -2);
  acc = Mul(acc, scale);

  auto want = LiteralUtil::CreateR3<float>(
      {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}},
       {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
706 
// Like Add1DTo3DInDimAll, but the rank-3 operand is produced by broadcasting
// the scalar 3 to [2, 2, 2] and the final scale factor is -1. Each expected
// element is therefore -3 - 3 * (vec0[i] + vec1[j] + vec2[k]).
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
  XlaBuilder builder(TestName());

  auto vec_dim0 = ConstantR1<float>(&builder, {1000, 2000});
  auto vec_dim1 = ConstantR1<float>(&builder, {100, 200});
  auto vec_dim2 = ConstantR1<float>(&builder, {10, 20});
  auto scalar = ConstantR0<float>(&builder, 3);
  auto acc = Broadcast(scalar, {2, 2, 2});

  // The accumulator alternates between the rhs and lhs position across the
  // three adds of a round.
  for (int round = 0; round != 3; ++round) {
    acc = Add(vec_dim0, acc, {0});
    acc = Add(acc, vec_dim1, {1});
    acc = Add(vec_dim2, acc, {2});
  }
  auto scale = ConstantR0<float>(&builder, -1);
  acc = Mul(acc, scale);

  auto want = LiteralUtil::CreateR3<float>(
      {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}},
       {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}});

  ComputeAndCompareLiteral(&builder, want, {}, ErrorSpec(0.0001));
}
727 
// In-dimension broadcasting of the smaller lhs ([2, 2], mapped onto
// dimensions {1, 2}) yields a shape that is incompatible with the rhs
// [2, 3, 1]; execution must fail with a dimension-mismatch error.
XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
  XlaBuilder builder(TestName());

  auto lhs = ConstantR2<float>(&builder, {{1.0, 5.0}, {1.0, 5.0}});
  auto rhs = ConstantLiteral(
      &builder, LiteralUtil::CreateR3<float>(
                    {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}}));
  Add(lhs, rhs, /*broadcast_dimensions=*/{1, 2});

  auto outcome = Execute(&builder, {});
  EXPECT_FALSE(outcome.ok());
  EXPECT_THAT(outcome.status().error_message(),
              HasSubstr("dimension 0 mismatch"));
}
743 
// Adding a [1, 2] operand to a [2, 3] operand without broadcast dimensions is
// invalid (2 vs. 3 in the minor dimension); execution must report
// incompatible shapes.
// NOTE(review): this body builds the same computation as
// InvalidDegenerateBroadcasting in this file -- confirm whether one of the two
// was meant to exercise explicit broadcast_dimensions.
XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
  XlaBuilder builder(TestName());

  auto lhs = ConstantR2<float>(&builder, {{1.0, 2.0}});
  auto rhs = ConstantR2<float>(&builder, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  Add(lhs, rhs);

  auto outcome = Execute(&builder, {});
  EXPECT_FALSE(outcome.ok());
  EXPECT_THAT(outcome.status().error_message(),
              HasSubstr("op add with incompatible shapes"));
}
756 
// Degenerate broadcasting cannot reconcile a [1, 2] operand with a [2, 3]
// operand: dim 0 (1 vs. 2) could be expanded, but dim 1 (2 vs. 3) cannot.
// Execution must report incompatible shapes.
XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
  XlaBuilder builder(TestName());

  auto row = ConstantR2<float>(&builder, {{1.0, 2.0}});
  auto matrix =
      ConstantR2<float>(&builder, {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}});
  Add(row, matrix);

  auto execution = Execute(&builder, {});
  EXPECT_FALSE(execution.ok());
  EXPECT_THAT(execution.status().error_message(),
              HasSubstr("op add with incompatible shapes"));
}
769 
770 }  // namespace
771 }  // namespace xla
772