xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/reference_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/reference_util.h"
17 
18 #include <array>
19 #include <memory>
20 #include <utility>
21 
22 #include "absl/container/flat_hash_set.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/shape_inference.h"
28 #include "tensorflow/compiler/xla/window_util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/math/math_util.h"
31 #include "tensorflow/core/platform/logging.h"
32 
33 namespace xla {
34 
Array2DF32ToF64(const Array2D<float> & input)35 /* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
36     const Array2D<float>& input) {
37   auto result =
38       std::make_unique<Array2D<double>>(input.height(), input.width());
39   for (int64_t rowno = 0; rowno < input.height(); ++rowno) {
40     for (int64_t colno = 0; colno < input.height(); ++colno) {
41       (*result)(rowno, colno) = input(rowno, colno);
42     }
43   }
44   return result;
45 }
46 
ConvArray3D(const Array3D<float> & lhs,const Array3D<float> & rhs,int64_t kernel_stride,Padding padding)47 /*  static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D(
48     const Array3D<float>& lhs, const Array3D<float>& rhs, int64_t kernel_stride,
49     Padding padding) {
50   return ConvArray3DGeneralDimensionsDilated(
51       lhs, rhs, kernel_stride, padding, 1, 1,
52       XlaBuilder::CreateDefaultConvDimensionNumbers(1));
53 }
54 
55 /*static*/ std::unique_ptr<Array3D<float>>
ConvArray3DGeneralDimensionsDilated(const Array3D<float> & lhs,const Array3D<float> & rhs,int64_t kernel_stride,Padding padding,int64_t lhs_dilation,int64_t rhs_dilation,const ConvolutionDimensionNumbers & dnums)56 ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
57     const Array3D<float>& lhs, const Array3D<float>& rhs, int64_t kernel_stride,
58     Padding padding, int64_t lhs_dilation, int64_t rhs_dilation,
59     const ConvolutionDimensionNumbers& dnums) {
60   CHECK_EQ(dnums.input_spatial_dimensions_size(), 1);
61   CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1);
62   CHECK_EQ(dnums.output_spatial_dimensions_size(), 1);
63   // Reuse the code for Array4D-convolution by extending the 3D input into a 4D
64   // array by adding a fourth dummy dimension of size 1 without stride, padding
65   // and dilation.
66   Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
67   a4dlhs.Each([&](absl::Span<const int64_t> indices, float* value_ptr) {
68     CHECK_EQ(indices[3], 0);
69     *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
70   });
71   Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
72   a4drhs.Each([&](absl::Span<const int64_t> indices, float* value_ptr) {
73     CHECK_EQ(indices[3], 0);
74     *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
75   });
76   // Add a second dummy spatial dimensions.
77   ConvolutionDimensionNumbers dnums2d = dnums;
78   dnums2d.add_input_spatial_dimensions(3);
79   dnums2d.add_kernel_spatial_dimensions(3);
80   dnums2d.add_output_spatial_dimensions(3);
81   std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated(
82       a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
83       {rhs_dilation, 1}, dnums2d);
84 
85   auto convr3 = std::make_unique<Array3D<float>>(
86       convr4->planes(), convr4->depth(), convr4->height());
87   convr4->Each([&](absl::Span<const int64_t> indices, float* value_ptr) {
88     CHECK_EQ(indices[3], 0);
89     convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
90   });
91   return convr3;
92 }
93 
ConvArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64_t,int64_t> kernel_stride,Padding padding)94 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
95     const Array4D<float>& lhs, const Array4D<float>& rhs,
96     std::pair<int64_t, int64_t> kernel_stride, Padding padding) {
97   return ConvArray4DGeneralDimensions(
98       lhs, rhs, kernel_stride, padding,
99       XlaBuilder::CreateDefaultConvDimensionNumbers());
100 }
101 
102 /* static */ std::unique_ptr<Array4D<float>>
SeparableConvArray4D(const Array4D<float> & input,const Array4D<float> & depthwise_weights,const Array4D<float> & pointwise_weights,std::pair<int64_t,int64_t> kernel_stride,Padding padding)103 ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
104                                     const Array4D<float>& depthwise_weights,
105                                     const Array4D<float>& pointwise_weights,
106                                     std::pair<int64_t, int64_t> kernel_stride,
107                                     Padding padding) {
108   const int64_t depth_multiplier = depthwise_weights.planes();
109   CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);
110 
111   // Combine the two weights by reducing the depth_multiplier, so that we can
112   // apply a single convolution on the combined weights.
113   Array4D<float> weights(pointwise_weights.planes(), input.depth(),
114                          depthwise_weights.height(), depthwise_weights.width());
115   for (int64_t kx = 0; kx < depthwise_weights.width(); ++kx) {
116     for (int64_t ky = 0; ky < depthwise_weights.height(); ++ky) {
117       for (int64_t kz = 0; kz < input.depth(); ++kz) {
118         for (int64_t out = 0; out < pointwise_weights.planes(); ++out) {
119           float weight = 0.0;
120           for (int64_t depth = 0; depth < depth_multiplier; ++depth) {
121             weight +=
122                 depthwise_weights(depth, kz, ky, kx) *
123                 pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
124           }
125           weights(out, kz, ky, kx) = weight;
126         }
127       }
128     }
129   }
130 
131   return ConvArray4D(input, weights, kernel_stride, padding);
132 }
133 
WindowCount(int64_t unpadded_width,int64_t window_len,int64_t stride,Padding padding)134 /* static */ int64_t ReferenceUtil::WindowCount(int64_t unpadded_width,
135                                                 int64_t window_len,
136                                                 int64_t stride,
137                                                 Padding padding) {
138   if (padding == Padding::kValid) {
139     return window_util::StridedBound(unpadded_width, window_len, stride);
140   }
141   return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
142 }
143 
144 /* static  */ std::unique_ptr<std::vector<float>>
ReduceWindow1DGeneric(absl::Span<const float> operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64_t> window,absl::Span<const int64_t> stride,absl::Span<const std::pair<int64_t,int64_t>> padding)145 ReferenceUtil::ReduceWindow1DGeneric(
146     absl::Span<const float> operand, float init,
147     const std::function<float(float, float)>& reduce_func,
148     absl::Span<const int64_t> window, absl::Span<const int64_t> stride,
149     absl::Span<const std::pair<int64_t, int64_t>> padding) {
150   CHECK_EQ(window.size(), 1);
151   CHECK_EQ(stride.size(), 1);
152   CHECK_EQ(padding.size(), 1);
153 
154   int64_t padded_width = padding[0].first + operand.size() + padding[0].second;
155   int64_t stride_amount = stride[0];
156   int64_t window_size = window[0];
157   int64_t result_size =
158       window_util::StridedBound(padded_width, window_size, stride_amount);
159   int64_t pad_low = padding[0].first;
160   auto result = std::make_unique<std::vector<float>>(result_size);
161 
162   // Do a full 1D reduce window.
163   for (int64_t i0 = 0; i0 < result_size; ++i0) {
164     int64_t i0_base = i0 * stride_amount - pad_low;
165     float val = init;
166     for (int64_t i0_win = 0; i0_win < window_size; ++i0_win) {
167       if (i0_base + i0_win >= 0 && i0_base + i0_win < operand.size()) {
168         val = reduce_func(val, operand[i0_base + i0_win]);
169       }
170     }
171     (*result)[i0] = val;
172   }
173   return result;
174 }
175 
176 /* static  */ std::unique_ptr<std::vector<float>>
ReduceWindow1DAdd(absl::Span<const float> operand,float init,absl::Span<const int64_t> window,absl::Span<const int64_t> stride,Padding padding)177 ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init,
178                                  absl::Span<const int64_t> window,
179                                  absl::Span<const int64_t> stride,
180                                  Padding padding) {
181   const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
182   std::vector<int64_t> dim_lengths{static_cast<int64_t>(operand.size())};
183   return ReduceWindow1DGeneric(
184       operand, init, add_reduce, window, stride,
185       xla::MakePadding(dim_lengths, window, stride, padding));
186 }
187 
ReduceWindow3DAdd(const Array3D<float> & operand,float init,absl::Span<const int64_t> window,absl::Span<const int64_t> stride,Padding padding)188 /* static  */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
189     const Array3D<float>& operand, float init, absl::Span<const int64_t> window,
190     absl::Span<const int64_t> stride, Padding padding) {
191   std::vector<int64_t> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
192   auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
193 
194   std::vector<int64_t> window_counts(window.size(), 0);
195   std::vector<int64_t> pad_low(window.size(), 0);
196   for (int64_t i = 0; i < window.size(); ++i) {
197     window_counts[i] =
198         WindowCount(dim_lengths[i], window[i], stride[i], padding);
199     pad_low[i] = padding_both[i].first;
200   }
201   auto result = std::make_unique<Array3D<float>>(
202       window_counts[0], window_counts[1], window_counts[2]);
203 
204   for (int64_t i0 = 0; i0 < window_counts[0]; ++i0) {
205     for (int64_t i1 = 0; i1 < window_counts[1]; ++i1) {
206       for (int64_t i2 = 0; i2 < window_counts[2]; ++i2) {
207         int64_t i0_base = i0 * stride[0] - pad_low[0];
208         int64_t i1_base = i1 * stride[1] - pad_low[1];
209         int64_t i2_base = i2 * stride[2] - pad_low[2];
210 
211         float val = init;
212         for (int64_t i0_win = 0; i0_win < window[0]; ++i0_win) {
213           for (int64_t i1_win = 0; i1_win < window[1]; ++i1_win) {
214             for (int64_t i2_win = 0; i2_win < window[2]; ++i2_win) {
215               if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
216                   i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() &&
217                   i1_base + i1_win < operand.n2() &&
218                   i2_base + i2_win < operand.n3()) {
219                 val += operand(i0_base + i0_win, i1_base + i1_win,
220                                i2_base + i2_win);
221               }
222             }
223           }
224         }
225         (*result)(i0, i1, i2) = val;
226       }
227     }
228   }
229   return result;
230 }
231 
232 /* static */ std::unique_ptr<Array4D<float>>
ReduceWindow4DGeneric(const Array4D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64_t> window,absl::Span<const int64_t> stride,Padding padding)233 ReferenceUtil::ReduceWindow4DGeneric(
234     const Array4D<float>& operand, float init,
235     const std::function<float(float, float)>& reduce_func,
236     absl::Span<const int64_t> window, absl::Span<const int64_t> stride,
237     Padding padding) {
238   std::vector<int64_t> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
239                                    operand.n4()};
240   return ReduceWindow4DGeneric(
241       operand, init, reduce_func, window, stride,
242       xla::MakePadding(dim_lengths, window, stride, padding));
243 }
244 
245 /* static */ std::unique_ptr<Array4D<float>>
ReduceWindow4DGeneric(const Array4D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64_t> window,absl::Span<const int64_t> stride,absl::Span<const std::pair<int64_t,int64_t>> padding)246 ReferenceUtil::ReduceWindow4DGeneric(
247     const Array4D<float>& operand, float init,
248     const std::function<float(float, float)>& reduce_func,
249     absl::Span<const int64_t> window, absl::Span<const int64_t> stride,
250     absl::Span<const std::pair<int64_t, int64_t>> padding) {
251   std::vector<int64_t> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
252                                    operand.n4()};
253 
254   std::vector<int64_t> window_counts(window.size(), 0);
255   std::vector<int64_t> pad_low(window.size(), 0);
256   for (int64_t i = 0; i < window.size(); ++i) {
257     int64_t padded_width =
258         padding[i].first + dim_lengths[i] + padding[i].second;
259     window_counts[i] =
260         window_util::StridedBound(padded_width, window[i], stride[i]);
261     pad_low[i] = padding[i].first;
262   }
263   auto result = std::make_unique<Array4D<float>>(
264       window_counts[0], window_counts[1], window_counts[2], window_counts[3]);
265   // Do a full 4D reduce window.
266   for (int64_t i0 = 0; i0 < window_counts[0]; ++i0) {
267     for (int64_t i1 = 0; i1 < window_counts[1]; ++i1) {
268       for (int64_t i2 = 0; i2 < window_counts[2]; ++i2) {
269         for (int64_t i3 = 0; i3 < window_counts[3]; ++i3) {
270           int64_t i0_base = i0 * stride[0] - pad_low[0];
271           int64_t i1_base = i1 * stride[1] - pad_low[1];
272           int64_t i2_base = i2 * stride[2] - pad_low[2];
273           int64_t i3_base = i3 * stride[3] - pad_low[3];
274 
275           float val = init;
276           for (int64_t i0_win = 0; i0_win < window[0]; ++i0_win) {
277             for (int64_t i1_win = 0; i1_win < window[1]; ++i1_win) {
278               for (int64_t i2_win = 0; i2_win < window[2]; ++i2_win) {
279                 for (int64_t i3_win = 0; i3_win < window[3]; ++i3_win) {
280                   if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
281                       i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
282                       i0_base + i0_win < operand.n1() &&
283                       i1_base + i1_win < operand.n2() &&
284                       i2_base + i2_win < operand.n3() &&
285                       i3_base + i3_win < operand.n4()) {
286                     val = reduce_func(
287                         val, operand(i0_base + i0_win, i1_base + i1_win,
288                                      i2_base + i2_win, i3_base + i3_win));
289                   }
290                 }
291               }
292             }
293           }
294           (*result)(i0, i1, i2, i3) = val;
295         }
296       }
297     }
298   }
299   return result;
300 }
301 
ReduceWindow4DAdd(const Array4D<float> & operand,float init,absl::Span<const int64_t> window,absl::Span<const int64_t> stride,Padding padding)302 /* static  */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
303     const Array4D<float>& operand, float init, absl::Span<const int64_t> window,
304     absl::Span<const int64_t> stride, Padding padding) {
305   const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
306   return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
307                                padding);
308 }
309 
BatchNorm4D(const Array4D<float> & input,const Array4D<float> & mean,const Array4D<float> & var,const Array4D<float> & scale,const Array4D<float> & offset,float epsilon)310 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D(
311     const Array4D<float>& input, const Array4D<float>& mean,
312     const Array4D<float>& var, const Array4D<float>& scale,
313     const Array4D<float>& offset, float epsilon) {
314   auto normalized =
315       *MapArray4D(input, mean, [](float a, float b) { return a - b; });
316   normalized = *MapArray4D(normalized, var, [&](float a, float b) {
317     return a / std::sqrt(b + epsilon);
318   });
319   normalized =
320       *MapArray4D(normalized, scale, [](float a, float b) { return a * b; });
321   return MapArray4D(normalized, offset, [](float a, float b) { return a + b; });
322 }
323 
324 /* static  */ std::unique_ptr<Array4D<float>>
SelectAndScatter4DGePlus(const Array4D<float> & operand,const Array4D<float> & source,float init,absl::Span<const int64_t> window,absl::Span<const int64_t> stride,bool same_padding)325 ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
326                                         const Array4D<float>& source,
327                                         float init,
328                                         absl::Span<const int64_t> window,
329                                         absl::Span<const int64_t> stride,
330                                         bool same_padding) {
331   Padding padding = same_padding ? Padding::kSame : Padding::kValid;
332   auto result = std::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
333                                                  operand.n3(), operand.n4());
334   std::vector<int64_t> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
335                                    operand.n4()};
336   auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
337   // Fill the output, with the initial value.
338   result->Fill(init);
339 
340   std::vector<int64_t> window_counts(window.size(), 0);
341   std::vector<int64_t> pad_low(window.size(), 0);
342   for (int64_t i = 0; i < window.size(); ++i) {
343     window_counts[i] =
344         WindowCount(dim_lengths[i], window[i], stride[i], padding);
345     pad_low[i] = padding_both[i].first;
346   }
347   CHECK_EQ(window_counts[0], source.n1());
348   CHECK_EQ(window_counts[1], source.n2());
349   CHECK_EQ(window_counts[2], source.n3());
350   CHECK_EQ(window_counts[3], source.n4());
351 
352   // Do a full 4D select and Scatter.
353   for (int64_t i0 = 0; i0 < window_counts[0]; ++i0) {
354     for (int64_t i1 = 0; i1 < window_counts[1]; ++i1) {
355       for (int64_t i2 = 0; i2 < window_counts[2]; ++i2) {
356         for (int64_t i3 = 0; i3 < window_counts[3]; ++i3) {
357           // Now we are inside a window and need to find the max and the argmax.
358           int64_t i0_base = i0 * stride[0] - pad_low[0];
359           int64_t i1_base = i1 * stride[1] - pad_low[1];
360           int64_t i2_base = i2 * stride[2] - pad_low[2];
361           int64_t i3_base = i3 * stride[3] - pad_low[3];
362           int64_t scatter_0 = (i0_base >= 0) ? i0_base : 0;
363           int64_t scatter_1 = (i1_base >= 0) ? i1_base : 0;
364           int64_t scatter_2 = (i2_base >= 0) ? i2_base : 0;
365           int64_t scatter_3 = (i3_base >= 0) ? i3_base : 0;
366           float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
367           for (int64_t i0_win = 0; i0_win < window[0]; ++i0_win) {
368             for (int64_t i1_win = 0; i1_win < window[1]; ++i1_win) {
369               for (int64_t i2_win = 0; i2_win < window[2]; ++i2_win) {
370                 for (int64_t i3_win = 0; i3_win < window[3]; ++i3_win) {
371                   if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
372                       i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
373                       i0_base + i0_win < operand.n1() &&
374                       i1_base + i1_win < operand.n2() &&
375                       i2_base + i2_win < operand.n3() &&
376                       i3_base + i3_win < operand.n4()) {
377                     float tmp = operand(i0_base + i0_win, i1_base + i1_win,
378                                         i2_base + i2_win, i3_base + i3_win);
379                     if (tmp > val) {
380                       val = tmp;
381                       scatter_0 = i0_base + i0_win;
382                       scatter_1 = i1_base + i1_win;
383                       scatter_2 = i2_base + i2_win;
384                       scatter_3 = i3_base + i3_win;
385                     }
386                   }
387                 }
388               }
389             }
390           }
391           (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
392               source(i0, i1, i2, i3);
393         }
394       }
395     }
396   }
397   return result;
398 }
399 
400 /* static */ std::unique_ptr<Array4D<float>>
ConvArray4DGeneralDimensions(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64_t,int64_t> kernel_stride,Padding padding,ConvolutionDimensionNumbers dimension_numbers)401 ReferenceUtil::ConvArray4DGeneralDimensions(
402     const Array4D<float>& lhs, const Array4D<float>& rhs,
403     std::pair<int64_t, int64_t> kernel_stride, Padding padding,
404     ConvolutionDimensionNumbers dimension_numbers) {
405   return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
406                                              {1, 1}, {1, 1},
407                                              std::move(dimension_numbers));
408 }
409 
410 /* static */ std::unique_ptr<Array4D<float>>
ConvArray4DGeneralDimensionsDilated(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64_t,int64_t> kernel_stride,Padding padding,std::pair<int64_t,int64_t> lhs_dilation,std::pair<int64_t,int64_t> rhs_dilation,ConvolutionDimensionNumbers dnums)411 ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
412     const Array4D<float>& lhs, const Array4D<float>& rhs,
413     std::pair<int64_t, int64_t> kernel_stride, Padding padding,
414     std::pair<int64_t, int64_t> lhs_dilation,
415     std::pair<int64_t, int64_t> rhs_dilation,
416     ConvolutionDimensionNumbers dnums) {
417   HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
418   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs);
419   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs);
420 
421   std::array<int64_t, 2> ordered_kernel_strides;
422   std::array<int64_t, 2> ordered_input_dimensions;
423   std::array<int64_t, 2> ordered_kernel_dimensions;
424   if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
425     ordered_kernel_strides[0] = kernel_stride.second;
426     ordered_kernel_strides[1] = kernel_stride.first;
427   } else {
428     ordered_kernel_strides[0] = kernel_stride.first;
429     ordered_kernel_strides[1] = kernel_stride.second;
430   }
431 
432   ordered_input_dimensions[0] =
433       lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
434   ordered_input_dimensions[1] =
435       lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
436   ordered_kernel_dimensions[0] =
437       rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
438   ordered_kernel_dimensions[1] =
439       rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));
440 
441   std::vector<std::pair<int64_t, int64_t>> paddings =
442       MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
443                   ordered_kernel_strides, padding);
444   CHECK_EQ(paddings.size(), 2);
445 
446   Window window;
447 
448   WindowDimension dim;
449   dim.set_size(
450       rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
451   dim.set_stride(kernel_stride.first);
452   dim.set_padding_low(paddings[0].first);
453   dim.set_padding_high(paddings[0].second);
454   dim.set_window_dilation(rhs_dilation.first);
455   dim.set_base_dilation(lhs_dilation.first);
456   *window.add_dimensions() = dim;
457 
458   WindowDimension dim2;
459   dim2.set_size(
460       rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
461   dim2.set_stride(kernel_stride.second);
462   dim2.set_padding_low(paddings[1].first);
463   dim2.set_padding_high(paddings[1].second);
464   dim2.set_window_dilation(rhs_dilation.second);
465   dim2.set_base_dilation(lhs_dilation.second);
466   *window.add_dimensions() = dim2;
467 
468   const Shape shape =
469       ShapeInference::InferConvolveShape(
470           lhs_literal.shape(), rhs_literal.shape(),
471           /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
472           /*preferred_element_type=*/std::nullopt)
473           .value();
474 
475   HloInstruction* lhs_instruction =
476       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
477   HloInstruction* rhs_instruction =
478       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
479 
480   PrecisionConfig precision_config;
481   precision_config.mutable_operand_precision()->Resize(
482       /*new_size=*/2, PrecisionConfig::DEFAULT);
483   b.AddInstruction(HloInstruction::CreateConvolve(
484       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
485       /*batch_group_count=*/1, window, dnums, precision_config));
486   HloModuleConfig config;
487   HloModule module("ReferenceUtil", config);
488   auto computation = module.AddEntryComputation(b.Build());
489 
490   HloEvaluator evaluator;
491   Literal result_literal = evaluator.Evaluate(*computation, {}).value();
492 
493   CHECK_EQ(result_literal.shape().rank(), 4);
494   auto result =
495       std::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
496                                        result_literal.shape().dimensions(1),
497                                        result_literal.shape().dimensions(2),
498                                        result_literal.shape().dimensions(3));
499 
500   result->Each([&](absl::Span<const int64_t> indices, float* value) {
501     *value = result_literal.Get<float>(indices);
502   });
503 
504   return result;
505 }
506 
507 /* static */ std::unique_ptr<std::vector<float>>
ReduceToColArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)508 ReferenceUtil::ReduceToColArray2D(
509     const Array2D<float>& matrix, float init,
510     const std::function<float(float, float)>& reduce_function) {
511   int64_t rows = matrix.height();
512   int64_t cols = matrix.width();
513   auto result = std::make_unique<std::vector<float>>();
514   for (int64_t i = 0; i < rows; ++i) {
515     float acc = init;
516     for (int64_t j = 0; j < cols; ++j) {
517       acc = reduce_function(acc, matrix(i, j));
518     }
519     result->push_back(acc);
520   }
521   return result;
522 }
523 
524 /* static */ std::unique_ptr<std::vector<float>>
ReduceToRowArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)525 ReferenceUtil::ReduceToRowArray2D(
526     const Array2D<float>& matrix, float init,
527     const std::function<float(float, float)>& reduce_function) {
528   int64_t rows = matrix.height();
529   int64_t cols = matrix.width();
530   auto result = std::make_unique<std::vector<float>>();
531   for (int64_t i = 0; i < cols; ++i) {
532     float acc = init;
533     for (int64_t j = 0; j < rows; ++j) {
534       acc = reduce_function(acc, matrix(j, i));
535     }
536     result->push_back(acc);
537   }
538   return result;
539 }
540 
Reduce4DTo1D(const Array4D<float> & array,float init,absl::Span<const int64_t> dims,const std::function<float (float,float)> & reduce_function)541 /*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
542     const Array4D<float>& array, float init, absl::Span<const int64_t> dims,
543     const std::function<float(float, float)>& reduce_function) {
544   std::vector<float> result;
545   CHECK_EQ(dims.size(), 3);
546   const absl::flat_hash_set<int64_t> dim_set(dims.begin(), dims.end());
547   CHECK_EQ(dim_set.size(), 3);
548   for (int64_t a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1());
549        ++a0) {
550     for (int64_t a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2());
551          ++a1) {
552       for (int64_t a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3());
553            ++a2) {
554         for (int64_t a3 = 0;
555              a3 == 0 || (!dim_set.contains(3) && a3 < array.n4()); ++a3) {
556           float accumulator = init;
557           for (int64_t i0 = 0;
558                i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) {
559             for (int64_t i1 = 0;
560                  i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) {
561               for (int64_t i2 = 0;
562                    i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) {
563                 for (int64_t i3 = 0;
564                      i3 == 0 || (dim_set.contains(3) && i3 < array.n4());
565                      ++i3) {
566                   // Handle zero-sized arrays.
567                   if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 &&
568                       array.n4() > 0) {
569                     accumulator = reduce_function(
570                         accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
571                   }
572                 }
573               }
574             }
575           }
576           result.push_back(accumulator);
577         }
578       }
579     }
580   }
581   return result;
582 }
583 
Broadcast1DTo4D(const std::vector<float> & array,const std::vector<int64_t> & bounds,int64_t broadcast_from_dim)584 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
585     const std::vector<float>& array, const std::vector<int64_t>& bounds,
586     int64_t broadcast_from_dim) {
587   auto result = std::make_unique<Array4D<float>>(bounds[0], bounds[1],
588                                                  bounds[2], bounds[3]);
589   for (int64_t i = 0; i < result->n1(); ++i) {
590     for (int64_t j = 0; j < result->n2(); ++j) {
591       for (int64_t k = 0; k < result->n3(); ++k) {
592         for (int64_t l = 0; l < result->n4(); ++l) {
593           switch (broadcast_from_dim) {
594             case 0:
595               (*result)(i, j, k, l) = array[i];
596               break;
597             case 1:
598               (*result)(i, j, k, l) = array[j];
599               break;
600             case 2:
601               (*result)(i, j, k, l) = array[k];
602               break;
603             case 3:
604               (*result)(i, j, k, l) = array[l];
605               break;
606             default:
607               break;
608           }
609         }
610       }
611     }
612   }
613   return result;
614 }
615 
Reduce3DTo2D(const Array3D<float> & array,float init,absl::Span<const int64_t> dims,const std::function<float (float,float)> & reduce_function)616 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
617     const Array3D<float>& array, float init, absl::Span<const int64_t> dims,
618     const std::function<float(float, float)>& reduce_function) {
619   CHECK_EQ(dims.size(), 1);
620   int64_t rows = dims[0] == 0 ? array.n2() : array.n1();
621   int64_t cols = dims[0] == 2 ? array.n2() : array.n3();
622   auto result = std::make_unique<Array2D<float>>(rows, cols);
623   result->Fill(init);
624   for (int i0 = 0; i0 < array.n1(); ++i0) {
625     for (int i1 = 0; i1 < array.n2(); ++i1) {
626       for (int i2 = 0; i2 < array.n3(); ++i2) {
627         int64_t row = dims[0] == 0 ? i1 : i0;
628         int64_t col = dims[0] == 2 ? i1 : i2;
629         (*result)(row, col) =
630             reduce_function((*result)(row, col), array(i0, i1, i2));
631       }
632     }
633   }
634   return result;
635 }
636 
MapArray2D(const Array2D<float> & matrix,const std::function<float (float)> & map_function)637 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
638     const Array2D<float>& matrix,
639     const std::function<float(float)>& map_function) {
640   int64_t rows = matrix.height();
641   int64_t cols = matrix.width();
642   auto result = std::make_unique<Array2D<float>>(rows, cols);
643   for (int64_t i = 0; i < rows; ++i) {
644     for (int64_t j = 0; j < cols; ++j) {
645       (*result)(i, j) = map_function(matrix(i, j));
646     }
647   }
648   return result;
649 }
650 
MapArray2D(const Array2D<float> & lhs,const Array2D<float> & rhs,const std::function<float (float,float)> & map_function)651 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
652     const Array2D<float>& lhs, const Array2D<float>& rhs,
653     const std::function<float(float, float)>& map_function) {
654   CHECK_EQ(lhs.height(), rhs.height());
655   CHECK_EQ(lhs.width(), rhs.width());
656   int64_t rows = lhs.height();
657   int64_t cols = rhs.width();
658   auto result = std::make_unique<Array2D<float>>(rows, cols);
659   for (int64_t i = 0; i < rows; ++i) {
660     for (int64_t j = 0; j < cols; ++j) {
661       (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
662     }
663   }
664   return result;
665 }
666 
MapArray3D(const Array3D<float> & array,const std::function<float (float)> & map_function)667 /* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::MapArray3D(
668     const Array3D<float>& array,
669     const std::function<float(float)>& map_function) {
670   int64_t n1 = array.n1();
671   int64_t n2 = array.n2();
672   int64_t n3 = array.n3();
673   auto result = std::make_unique<Array3D<float>>(n1, n2, n3);
674   for (int64_t i = 0; i < n1; ++i) {
675     for (int64_t j = 0; j < n2; ++j) {
676       for (int64_t k = 0; k < n3; ++k) {
677         (*result)(i, j, k) = map_function(array(i, j, k));
678       }
679     }
680   }
681   return result;
682 }
683 
MapArray3D(const Array3D<float> & lhs,const Array3D<float> & rhs,const std::function<float (float,float)> & map_function)684 /* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::MapArray3D(
685     const Array3D<float>& lhs, const Array3D<float>& rhs,
686     const std::function<float(float, float)>& map_function) {
687   CHECK_EQ(lhs.n1(), rhs.n1());
688   CHECK_EQ(lhs.n2(), rhs.n2());
689   CHECK_EQ(lhs.n3(), rhs.n3());
690   int64_t n1 = lhs.n1();
691   int64_t n2 = rhs.n2();
692   int64_t n3 = rhs.n3();
693   auto result = std::make_unique<Array3D<float>>(n1, n2, n3);
694   for (int64_t i = 0; i < n1; ++i) {
695     for (int64_t j = 0; j < n2; ++j) {
696       for (int64_t k = 0; k < n3; ++k) {
697         (*result)(i, j, k) = map_function(lhs(i, j, k), rhs(i, j, k));
698       }
699     }
700   }
701   return result;
702 }
703 
MapWithIndexArray2D(const Array2D<float> & matrix,const std::function<float (float,int64_t,int64_t)> & map_function)704 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
705     const Array2D<float>& matrix,
706     const std::function<float(float, int64_t, int64_t)>& map_function) {
707   int64_t rows = matrix.height();
708   int64_t cols = matrix.width();
709   auto result = std::make_unique<Array2D<float>>(rows, cols);
710   for (int64_t i = 0; i < rows; ++i) {
711     for (int64_t j = 0; j < cols; ++j) {
712       (*result)(i, j) = map_function(matrix(i, j), i, j);
713     }
714   }
715   return result;
716 }
717 
718 }  // namespace xla
719