1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/reference_util.h"
17
18 #include <array>
19 #include <memory>
20 #include <utility>
21
22 #include "absl/container/flat_hash_set.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/shape_inference.h"
28 #include "tensorflow/compiler/xla/window_util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/math/math_util.h"
31 #include "tensorflow/core/platform/logging.h"
32
33 namespace xla {
34
Array2DF32ToF64(const Array2D<float> & input)35 /* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
36 const Array2D<float>& input) {
37 auto result =
38 std::make_unique<Array2D<double>>(input.height(), input.width());
39 for (int64_t rowno = 0; rowno < input.height(); ++rowno) {
40 for (int64_t colno = 0; colno < input.height(); ++colno) {
41 (*result)(rowno, colno) = input(rowno, colno);
42 }
43 }
44 return result;
45 }
46
ConvArray3D(const Array3D<float> & lhs,const Array3D<float> & rhs,int64_t kernel_stride,Padding padding)47 /* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D(
48 const Array3D<float>& lhs, const Array3D<float>& rhs, int64_t kernel_stride,
49 Padding padding) {
50 return ConvArray3DGeneralDimensionsDilated(
51 lhs, rhs, kernel_stride, padding, 1, 1,
52 XlaBuilder::CreateDefaultConvDimensionNumbers(1));
53 }
54
/*static*/ std::unique_ptr<Array3D<float>>
ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
    const Array3D<float>& lhs, const Array3D<float>& rhs, int64_t kernel_stride,
    Padding padding, int64_t lhs_dilation, int64_t rhs_dilation,
    const ConvolutionDimensionNumbers& dnums) {
  // Single-spatial-dimension convolution with dilation, implemented by
  // embedding both operands into 4D arrays with a trailing size-1 spatial
  // dimension and delegating to the 4D reference convolution below.
  CHECK_EQ(dnums.input_spatial_dimensions_size(), 1);
  CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1);
  CHECK_EQ(dnums.output_spatial_dimensions_size(), 1);
  // Reuse the code for Array4D-convolution by extending the 3D input into a 4D
  // array by adding a fourth dummy dimension of size 1 without stride, padding
  // and dilation.
  Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
  a4dlhs.Each([&](absl::Span<const int64_t> indices, float* value_ptr) {
    CHECK_EQ(indices[3], 0);
    *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
  });
  Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
  a4drhs.Each([&](absl::Span<const int64_t> indices, float* value_ptr) {
    CHECK_EQ(indices[3], 0);
    *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
  });
  // Add a second dummy spatial dimensions.
  ConvolutionDimensionNumbers dnums2d = dnums;
  dnums2d.add_input_spatial_dimensions(3);
  dnums2d.add_kernel_spatial_dimensions(3);
  dnums2d.add_output_spatial_dimensions(3);
  // The dummy dimension gets stride 1 and dilation 1 below, so it cannot
  // change the result values.
  std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated(
      a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
      {rhs_dilation, 1}, dnums2d);

  // Project the 4D result (whose trailing dimension has size 1) back to 3D.
  auto convr3 = std::make_unique<Array3D<float>>(
      convr4->planes(), convr4->depth(), convr4->height());
  convr4->Each([&](absl::Span<const int64_t> indices, float* value_ptr) {
    CHECK_EQ(indices[3], 0);
    convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
  });
  return convr3;
}
93
ConvArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64_t,int64_t> kernel_stride,Padding padding)94 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
95 const Array4D<float>& lhs, const Array4D<float>& rhs,
96 std::pair<int64_t, int64_t> kernel_stride, Padding padding) {
97 return ConvArray4DGeneralDimensions(
98 lhs, rhs, kernel_stride, padding,
99 XlaBuilder::CreateDefaultConvDimensionNumbers());
100 }
101
/* static */ std::unique_ptr<Array4D<float>>
ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
                                    const Array4D<float>& depthwise_weights,
                                    const Array4D<float>& pointwise_weights,
                                    std::pair<int64_t, int64_t> kernel_stride,
                                    Padding padding) {
  // Depthwise-separable convolution, evaluated by fusing the depthwise and
  // pointwise filters into one dense filter and running a single regular
  // convolution with the fused weights.
  const int64_t depth_multiplier = depthwise_weights.planes();
  CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);

  // Combine the two weights by reducing the depth_multiplier, so that we can
  // apply a single convolution on the combined weights.
  Array4D<float> weights(pointwise_weights.planes(), input.depth(),
                         depthwise_weights.height(), depthwise_weights.width());
  for (int64_t kx = 0; kx < depthwise_weights.width(); ++kx) {
    for (int64_t ky = 0; ky < depthwise_weights.height(); ++ky) {
      for (int64_t kz = 0; kz < input.depth(); ++kz) {
        for (int64_t out = 0; out < pointwise_weights.planes(); ++out) {
          float weight = 0.0;
          for (int64_t depth = 0; depth < depth_multiplier; ++depth) {
            // Pointwise channel (depth + kz * depth_multiplier) corresponds
            // to the depthwise output produced by input channel kz with
            // multiplier index `depth`.
            weight +=
                depthwise_weights(depth, kz, ky, kx) *
                pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
          }
          weights(out, kz, ky, kx) = weight;
        }
      }
    }
  }

  return ConvArray4D(input, weights, kernel_stride, padding);
}
133
WindowCount(int64_t unpadded_width,int64_t window_len,int64_t stride,Padding padding)134 /* static */ int64_t ReferenceUtil::WindowCount(int64_t unpadded_width,
135 int64_t window_len,
136 int64_t stride,
137 Padding padding) {
138 if (padding == Padding::kValid) {
139 return window_util::StridedBound(unpadded_width, window_len, stride);
140 }
141 return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
142 }
143
144 /* static */ std::unique_ptr<std::vector<float>>
ReduceWindow1DGeneric(absl::Span<const float> operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64_t> window,absl::Span<const int64_t> stride,absl::Span<const std::pair<int64_t,int64_t>> padding)145 ReferenceUtil::ReduceWindow1DGeneric(
146 absl::Span<const float> operand, float init,
147 const std::function<float(float, float)>& reduce_func,
148 absl::Span<const int64_t> window, absl::Span<const int64_t> stride,
149 absl::Span<const std::pair<int64_t, int64_t>> padding) {
150 CHECK_EQ(window.size(), 1);
151 CHECK_EQ(stride.size(), 1);
152 CHECK_EQ(padding.size(), 1);
153
154 int64_t padded_width = padding[0].first + operand.size() + padding[0].second;
155 int64_t stride_amount = stride[0];
156 int64_t window_size = window[0];
157 int64_t result_size =
158 window_util::StridedBound(padded_width, window_size, stride_amount);
159 int64_t pad_low = padding[0].first;
160 auto result = std::make_unique<std::vector<float>>(result_size);
161
162 // Do a full 1D reduce window.
163 for (int64_t i0 = 0; i0 < result_size; ++i0) {
164 int64_t i0_base = i0 * stride_amount - pad_low;
165 float val = init;
166 for (int64_t i0_win = 0; i0_win < window_size; ++i0_win) {
167 if (i0_base + i0_win >= 0 && i0_base + i0_win < operand.size()) {
168 val = reduce_func(val, operand[i0_base + i0_win]);
169 }
170 }
171 (*result)[i0] = val;
172 }
173 return result;
174 }
175
176 /* static */ std::unique_ptr<std::vector<float>>
ReduceWindow1DAdd(absl::Span<const float> operand,float init,absl::Span<const int64_t> window,absl::Span<const int64_t> stride,Padding padding)177 ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init,
178 absl::Span<const int64_t> window,
179 absl::Span<const int64_t> stride,
180 Padding padding) {
181 const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
182 std::vector<int64_t> dim_lengths{static_cast<int64_t>(operand.size())};
183 return ReduceWindow1DGeneric(
184 operand, init, add_reduce, window, stride,
185 xla::MakePadding(dim_lengths, window, stride, padding));
186 }
187
/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
    const Array3D<float>& operand, float init, absl::Span<const int64_t> window,
    absl::Span<const int64_t> stride, Padding padding) {
  // 3D reduce-window with '+' as the combiner.  Padding contributes nothing:
  // each window position simply skips indices that fall outside the operand.
  std::vector<int64_t> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);

  // Per-dimension number of window positions and amount of low padding.
  std::vector<int64_t> window_counts(window.size(), 0);
  std::vector<int64_t> pad_low(window.size(), 0);
  for (int64_t i = 0; i < window.size(); ++i) {
    window_counts[i] =
        WindowCount(dim_lengths[i], window[i], stride[i], padding);
    pad_low[i] = padding_both[i].first;
  }
  auto result = std::make_unique<Array3D<float>>(
      window_counts[0], window_counts[1], window_counts[2]);

  // For each output position, sum the in-bounds elements covered by the
  // window anchored at (i0, i1, i2).
  for (int64_t i0 = 0; i0 < window_counts[0]; ++i0) {
    for (int64_t i1 = 0; i1 < window_counts[1]; ++i1) {
      for (int64_t i2 = 0; i2 < window_counts[2]; ++i2) {
        // Window origin in (possibly negative) operand coordinates.
        int64_t i0_base = i0 * stride[0] - pad_low[0];
        int64_t i1_base = i1 * stride[1] - pad_low[1];
        int64_t i2_base = i2 * stride[2] - pad_low[2];

        float val = init;
        for (int64_t i0_win = 0; i0_win < window[0]; ++i0_win) {
          for (int64_t i1_win = 0; i1_win < window[1]; ++i1_win) {
            for (int64_t i2_win = 0; i2_win < window[2]; ++i2_win) {
              // Only fold in elements that actually exist in the operand.
              if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
                  i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() &&
                  i1_base + i1_win < operand.n2() &&
                  i2_base + i2_win < operand.n3()) {
                val += operand(i0_base + i0_win, i1_base + i1_win,
                               i2_base + i2_win);
              }
            }
          }
        }
        (*result)(i0, i1, i2) = val;
      }
    }
  }
  return result;
}
231
232 /* static */ std::unique_ptr<Array4D<float>>
ReduceWindow4DGeneric(const Array4D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64_t> window,absl::Span<const int64_t> stride,Padding padding)233 ReferenceUtil::ReduceWindow4DGeneric(
234 const Array4D<float>& operand, float init,
235 const std::function<float(float, float)>& reduce_func,
236 absl::Span<const int64_t> window, absl::Span<const int64_t> stride,
237 Padding padding) {
238 std::vector<int64_t> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
239 operand.n4()};
240 return ReduceWindow4DGeneric(
241 operand, init, reduce_func, window, stride,
242 xla::MakePadding(dim_lengths, window, stride, padding));
243 }
244
/* static */ std::unique_ptr<Array4D<float>>
ReferenceUtil::ReduceWindow4DGeneric(
    const Array4D<float>& operand, float init,
    const std::function<float(float, float)>& reduce_func,
    absl::Span<const int64_t> window, absl::Span<const int64_t> stride,
    absl::Span<const std::pair<int64_t, int64_t>> padding) {
  // 4D reduce-window with explicit per-dimension (low, high) padding.
  // Padded positions contribute nothing: the accumulation skips indices that
  // fall outside the operand.
  std::vector<int64_t> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
                                   operand.n4()};

  // Per-dimension output extent and low padding.
  std::vector<int64_t> window_counts(window.size(), 0);
  std::vector<int64_t> pad_low(window.size(), 0);
  for (int64_t i = 0; i < window.size(); ++i) {
    int64_t padded_width =
        padding[i].first + dim_lengths[i] + padding[i].second;
    window_counts[i] =
        window_util::StridedBound(padded_width, window[i], stride[i]);
    pad_low[i] = padding[i].first;
  }
  auto result = std::make_unique<Array4D<float>>(
      window_counts[0], window_counts[1], window_counts[2], window_counts[3]);
  // Do a full 4D reduce window.
  for (int64_t i0 = 0; i0 < window_counts[0]; ++i0) {
    for (int64_t i1 = 0; i1 < window_counts[1]; ++i1) {
      for (int64_t i2 = 0; i2 < window_counts[2]; ++i2) {
        for (int64_t i3 = 0; i3 < window_counts[3]; ++i3) {
          // Window origin in (possibly negative) operand coordinates.
          int64_t i0_base = i0 * stride[0] - pad_low[0];
          int64_t i1_base = i1 * stride[1] - pad_low[1];
          int64_t i2_base = i2 * stride[2] - pad_low[2];
          int64_t i3_base = i3 * stride[3] - pad_low[3];

          float val = init;
          for (int64_t i0_win = 0; i0_win < window[0]; ++i0_win) {
            for (int64_t i1_win = 0; i1_win < window[1]; ++i1_win) {
              for (int64_t i2_win = 0; i2_win < window[2]; ++i2_win) {
                for (int64_t i3_win = 0; i3_win < window[3]; ++i3_win) {
                  // Only fold in elements that actually exist in the operand.
                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
                      i0_base + i0_win < operand.n1() &&
                      i1_base + i1_win < operand.n2() &&
                      i2_base + i2_win < operand.n3() &&
                      i3_base + i3_win < operand.n4()) {
                    val = reduce_func(
                        val, operand(i0_base + i0_win, i1_base + i1_win,
                                     i2_base + i2_win, i3_base + i3_win));
                  }
                }
              }
            }
          }
          (*result)(i0, i1, i2, i3) = val;
        }
      }
    }
  }
  return result;
}
301
ReduceWindow4DAdd(const Array4D<float> & operand,float init,absl::Span<const int64_t> window,absl::Span<const int64_t> stride,Padding padding)302 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
303 const Array4D<float>& operand, float init, absl::Span<const int64_t> window,
304 absl::Span<const int64_t> stride, Padding padding) {
305 const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
306 return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
307 padding);
308 }
309
BatchNorm4D(const Array4D<float> & input,const Array4D<float> & mean,const Array4D<float> & var,const Array4D<float> & scale,const Array4D<float> & offset,float epsilon)310 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D(
311 const Array4D<float>& input, const Array4D<float>& mean,
312 const Array4D<float>& var, const Array4D<float>& scale,
313 const Array4D<float>& offset, float epsilon) {
314 auto normalized =
315 *MapArray4D(input, mean, [](float a, float b) { return a - b; });
316 normalized = *MapArray4D(normalized, var, [&](float a, float b) {
317 return a / std::sqrt(b + epsilon);
318 });
319 normalized =
320 *MapArray4D(normalized, scale, [](float a, float b) { return a * b; });
321 return MapArray4D(normalized, offset, [](float a, float b) { return a + b; });
322 }
323
/* static */ std::unique_ptr<Array4D<float>>
ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
                                        const Array4D<float>& source,
                                        float init,
                                        absl::Span<const int64_t> window,
                                        absl::Span<const int64_t> stride,
                                        bool same_padding) {
  // Reference select-and-scatter with a ">=" style select and "+" scatter:
  // for each window position, the corresponding source element is added into
  // the output at the location of the largest in-bounds operand element of
  // that window.  The output has the operand's shape and starts filled with
  // `init`.
  Padding padding = same_padding ? Padding::kSame : Padding::kValid;
  auto result = std::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
                                                 operand.n3(), operand.n4());
  std::vector<int64_t> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
                                   operand.n4()};
  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
  // Fill the output, with the initial value.
  result->Fill(init);

  // Per-dimension number of window positions and low padding.
  std::vector<int64_t> window_counts(window.size(), 0);
  std::vector<int64_t> pad_low(window.size(), 0);
  for (int64_t i = 0; i < window.size(); ++i) {
    window_counts[i] =
        WindowCount(dim_lengths[i], window[i], stride[i], padding);
    pad_low[i] = padding_both[i].first;
  }
  // The source must have one element per window position.
  CHECK_EQ(window_counts[0], source.n1());
  CHECK_EQ(window_counts[1], source.n2());
  CHECK_EQ(window_counts[2], source.n3());
  CHECK_EQ(window_counts[3], source.n4());

  // Do a full 4D select and Scatter.
  for (int64_t i0 = 0; i0 < window_counts[0]; ++i0) {
    for (int64_t i1 = 0; i1 < window_counts[1]; ++i1) {
      for (int64_t i2 = 0; i2 < window_counts[2]; ++i2) {
        for (int64_t i3 = 0; i3 < window_counts[3]; ++i3) {
          // Now we are inside a window and need to find the max and the argmax.
          int64_t i0_base = i0 * stride[0] - pad_low[0];
          int64_t i1_base = i1 * stride[1] - pad_low[1];
          int64_t i2_base = i2 * stride[2] - pad_low[2];
          int64_t i3_base = i3 * stride[3] - pad_low[3];
          // Seed the argmax with the window origin clamped into bounds; the
          // seed value is that element, so later candidates must strictly
          // exceed it to win (ties keep the earlier position).
          int64_t scatter_0 = (i0_base >= 0) ? i0_base : 0;
          int64_t scatter_1 = (i1_base >= 0) ? i1_base : 0;
          int64_t scatter_2 = (i2_base >= 0) ? i2_base : 0;
          int64_t scatter_3 = (i3_base >= 0) ? i3_base : 0;
          float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
          for (int64_t i0_win = 0; i0_win < window[0]; ++i0_win) {
            for (int64_t i1_win = 0; i1_win < window[1]; ++i1_win) {
              for (int64_t i2_win = 0; i2_win < window[2]; ++i2_win) {
                for (int64_t i3_win = 0; i3_win < window[3]; ++i3_win) {
                  // Consider only in-bounds candidates.
                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
                      i0_base + i0_win < operand.n1() &&
                      i1_base + i1_win < operand.n2() &&
                      i2_base + i2_win < operand.n3() &&
                      i3_base + i3_win < operand.n4()) {
                    float tmp = operand(i0_base + i0_win, i1_base + i1_win,
                                        i2_base + i2_win, i3_base + i3_win);
                    if (tmp > val) {
                      val = tmp;
                      scatter_0 = i0_base + i0_win;
                      scatter_1 = i1_base + i1_win;
                      scatter_2 = i2_base + i2_win;
                      scatter_3 = i3_base + i3_win;
                    }
                  }
                }
              }
            }
          }
          // Scatter the source value onto the selected (argmax) position.
          (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
              source(i0, i1, i2, i3);
        }
      }
    }
  }
  return result;
}
399
400 /* static */ std::unique_ptr<Array4D<float>>
ConvArray4DGeneralDimensions(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64_t,int64_t> kernel_stride,Padding padding,ConvolutionDimensionNumbers dimension_numbers)401 ReferenceUtil::ConvArray4DGeneralDimensions(
402 const Array4D<float>& lhs, const Array4D<float>& rhs,
403 std::pair<int64_t, int64_t> kernel_stride, Padding padding,
404 ConvolutionDimensionNumbers dimension_numbers) {
405 return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
406 {1, 1}, {1, 1},
407 std::move(dimension_numbers));
408 }
409
/* static */ std::unique_ptr<Array4D<float>>
ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
    const Array4D<float>& lhs, const Array4D<float>& rhs,
    std::pair<int64_t, int64_t> kernel_stride, Padding padding,
    std::pair<int64_t, int64_t> lhs_dilation,
    std::pair<int64_t, int64_t> rhs_dilation,
    ConvolutionDimensionNumbers dnums) {
  // Reference 4D convolution.  Rather than re-implementing convolution
  // arithmetic, this builds a tiny HLO module containing a single kConvolve
  // on constant operands and runs it through HloEvaluator.
  HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
  auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs);
  auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs);

  // MakePadding expects its arguments ordered by increasing spatial
  // dimension, so swap the stride pair if the kernel spatial dimensions are
  // listed in descending order.
  std::array<int64_t, 2> ordered_kernel_strides;
  std::array<int64_t, 2> ordered_input_dimensions;
  std::array<int64_t, 2> ordered_kernel_dimensions;
  if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
    ordered_kernel_strides[0] = kernel_stride.second;
    ordered_kernel_strides[1] = kernel_stride.first;
  } else {
    ordered_kernel_strides[0] = kernel_stride.first;
    ordered_kernel_strides[1] = kernel_stride.second;
  }

  ordered_input_dimensions[0] =
      lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
  ordered_input_dimensions[1] =
      lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
  ordered_kernel_dimensions[0] =
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
  ordered_kernel_dimensions[1] =
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));

  // Resolve the symbolic padding mode into explicit (low, high) pads for the
  // two spatial dimensions.
  std::vector<std::pair<int64_t, int64_t>> paddings =
      MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
                  ordered_kernel_strides, padding);
  CHECK_EQ(paddings.size(), 2);

  // Build the Window proto: one WindowDimension per spatial dimension,
  // carrying size, stride, padding and both dilation factors.
  Window window;

  WindowDimension dim;
  dim.set_size(
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
  dim.set_stride(kernel_stride.first);
  dim.set_padding_low(paddings[0].first);
  dim.set_padding_high(paddings[0].second);
  dim.set_window_dilation(rhs_dilation.first);
  dim.set_base_dilation(lhs_dilation.first);
  *window.add_dimensions() = dim;

  WindowDimension dim2;
  dim2.set_size(
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
  dim2.set_stride(kernel_stride.second);
  dim2.set_padding_low(paddings[1].first);
  dim2.set_padding_high(paddings[1].second);
  dim2.set_window_dilation(rhs_dilation.second);
  dim2.set_base_dilation(lhs_dilation.second);
  *window.add_dimensions() = dim2;

  // Let shape inference compute the output shape for the HLO instruction.
  const Shape shape =
      ShapeInference::InferConvolveShape(
          lhs_literal.shape(), rhs_literal.shape(),
          /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
          /*preferred_element_type=*/std::nullopt)
          .value();

  HloInstruction* lhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
  HloInstruction* rhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));

  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      /*new_size=*/2, PrecisionConfig::DEFAULT);
  b.AddInstruction(HloInstruction::CreateConvolve(
      shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, precision_config));
  HloModuleConfig config;
  HloModule module("ReferenceUtil", config);
  auto computation = module.AddEntryComputation(b.Build());

  // Evaluate the one-instruction module and copy the literal back into an
  // Array4D for the caller.
  HloEvaluator evaluator;
  Literal result_literal = evaluator.Evaluate(*computation, {}).value();

  CHECK_EQ(result_literal.shape().rank(), 4);
  auto result =
      std::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
                                       result_literal.shape().dimensions(1),
                                       result_literal.shape().dimensions(2),
                                       result_literal.shape().dimensions(3));

  result->Each([&](absl::Span<const int64_t> indices, float* value) {
    *value = result_literal.Get<float>(indices);
  });

  return result;
}
506
507 /* static */ std::unique_ptr<std::vector<float>>
ReduceToColArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)508 ReferenceUtil::ReduceToColArray2D(
509 const Array2D<float>& matrix, float init,
510 const std::function<float(float, float)>& reduce_function) {
511 int64_t rows = matrix.height();
512 int64_t cols = matrix.width();
513 auto result = std::make_unique<std::vector<float>>();
514 for (int64_t i = 0; i < rows; ++i) {
515 float acc = init;
516 for (int64_t j = 0; j < cols; ++j) {
517 acc = reduce_function(acc, matrix(i, j));
518 }
519 result->push_back(acc);
520 }
521 return result;
522 }
523
524 /* static */ std::unique_ptr<std::vector<float>>
ReduceToRowArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)525 ReferenceUtil::ReduceToRowArray2D(
526 const Array2D<float>& matrix, float init,
527 const std::function<float(float, float)>& reduce_function) {
528 int64_t rows = matrix.height();
529 int64_t cols = matrix.width();
530 auto result = std::make_unique<std::vector<float>>();
531 for (int64_t i = 0; i < cols; ++i) {
532 float acc = init;
533 for (int64_t j = 0; j < rows; ++j) {
534 acc = reduce_function(acc, matrix(j, i));
535 }
536 result->push_back(acc);
537 }
538 return result;
539 }
540
/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
    const Array4D<float>& array, float init, absl::Span<const int64_t> dims,
    const std::function<float(float, float)>& reduce_function) {
  // Reduces exactly three of the four dimensions of `array` away, producing
  // a flat vector over the single surviving dimension.
  //
  // The outer loops (a0..a3) iterate over the KEPT dimension: for a reduced
  // dimension the guard `aK == 0 || ...` makes that loop run exactly once.
  // The inner loops (i0..i3) iterate over the REDUCED dimensions and run
  // exactly once for the kept one, by the symmetric guard.
  std::vector<float> result;
  CHECK_EQ(dims.size(), 3);
  const absl::flat_hash_set<int64_t> dim_set(dims.begin(), dims.end());
  CHECK_EQ(dim_set.size(), 3);  // the three reduced dims must be distinct
  for (int64_t a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1());
       ++a0) {
    for (int64_t a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2());
         ++a1) {
      for (int64_t a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3());
           ++a2) {
        for (int64_t a3 = 0;
             a3 == 0 || (!dim_set.contains(3) && a3 < array.n4()); ++a3) {
          float accumulator = init;
          for (int64_t i0 = 0;
               i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) {
            for (int64_t i1 = 0;
                 i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) {
              for (int64_t i2 = 0;
                   i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) {
                for (int64_t i3 = 0;
                     i3 == 0 || (dim_set.contains(3) && i3 < array.n4());
                     ++i3) {
                  // Handle zero-sized arrays.
                  // (Each loop's first iteration runs unconditionally, so
                  // without this guard a zero-sized dimension would be read
                  // out of bounds; with it, the element pushed is `init`.)
                  if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 &&
                      array.n4() > 0) {
                    accumulator = reduce_function(
                        accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
                  }
                }
              }
            }
          }
          result.push_back(accumulator);
        }
      }
    }
  }
  return result;
}
583
Broadcast1DTo4D(const std::vector<float> & array,const std::vector<int64_t> & bounds,int64_t broadcast_from_dim)584 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
585 const std::vector<float>& array, const std::vector<int64_t>& bounds,
586 int64_t broadcast_from_dim) {
587 auto result = std::make_unique<Array4D<float>>(bounds[0], bounds[1],
588 bounds[2], bounds[3]);
589 for (int64_t i = 0; i < result->n1(); ++i) {
590 for (int64_t j = 0; j < result->n2(); ++j) {
591 for (int64_t k = 0; k < result->n3(); ++k) {
592 for (int64_t l = 0; l < result->n4(); ++l) {
593 switch (broadcast_from_dim) {
594 case 0:
595 (*result)(i, j, k, l) = array[i];
596 break;
597 case 1:
598 (*result)(i, j, k, l) = array[j];
599 break;
600 case 2:
601 (*result)(i, j, k, l) = array[k];
602 break;
603 case 3:
604 (*result)(i, j, k, l) = array[l];
605 break;
606 default:
607 break;
608 }
609 }
610 }
611 }
612 }
613 return result;
614 }
615
Reduce3DTo2D(const Array3D<float> & array,float init,absl::Span<const int64_t> dims,const std::function<float (float,float)> & reduce_function)616 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
617 const Array3D<float>& array, float init, absl::Span<const int64_t> dims,
618 const std::function<float(float, float)>& reduce_function) {
619 CHECK_EQ(dims.size(), 1);
620 int64_t rows = dims[0] == 0 ? array.n2() : array.n1();
621 int64_t cols = dims[0] == 2 ? array.n2() : array.n3();
622 auto result = std::make_unique<Array2D<float>>(rows, cols);
623 result->Fill(init);
624 for (int i0 = 0; i0 < array.n1(); ++i0) {
625 for (int i1 = 0; i1 < array.n2(); ++i1) {
626 for (int i2 = 0; i2 < array.n3(); ++i2) {
627 int64_t row = dims[0] == 0 ? i1 : i0;
628 int64_t col = dims[0] == 2 ? i1 : i2;
629 (*result)(row, col) =
630 reduce_function((*result)(row, col), array(i0, i1, i2));
631 }
632 }
633 }
634 return result;
635 }
636
MapArray2D(const Array2D<float> & matrix,const std::function<float (float)> & map_function)637 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
638 const Array2D<float>& matrix,
639 const std::function<float(float)>& map_function) {
640 int64_t rows = matrix.height();
641 int64_t cols = matrix.width();
642 auto result = std::make_unique<Array2D<float>>(rows, cols);
643 for (int64_t i = 0; i < rows; ++i) {
644 for (int64_t j = 0; j < cols; ++j) {
645 (*result)(i, j) = map_function(matrix(i, j));
646 }
647 }
648 return result;
649 }
650
MapArray2D(const Array2D<float> & lhs,const Array2D<float> & rhs,const std::function<float (float,float)> & map_function)651 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
652 const Array2D<float>& lhs, const Array2D<float>& rhs,
653 const std::function<float(float, float)>& map_function) {
654 CHECK_EQ(lhs.height(), rhs.height());
655 CHECK_EQ(lhs.width(), rhs.width());
656 int64_t rows = lhs.height();
657 int64_t cols = rhs.width();
658 auto result = std::make_unique<Array2D<float>>(rows, cols);
659 for (int64_t i = 0; i < rows; ++i) {
660 for (int64_t j = 0; j < cols; ++j) {
661 (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
662 }
663 }
664 return result;
665 }
666
MapArray3D(const Array3D<float> & array,const std::function<float (float)> & map_function)667 /* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::MapArray3D(
668 const Array3D<float>& array,
669 const std::function<float(float)>& map_function) {
670 int64_t n1 = array.n1();
671 int64_t n2 = array.n2();
672 int64_t n3 = array.n3();
673 auto result = std::make_unique<Array3D<float>>(n1, n2, n3);
674 for (int64_t i = 0; i < n1; ++i) {
675 for (int64_t j = 0; j < n2; ++j) {
676 for (int64_t k = 0; k < n3; ++k) {
677 (*result)(i, j, k) = map_function(array(i, j, k));
678 }
679 }
680 }
681 return result;
682 }
683
MapArray3D(const Array3D<float> & lhs,const Array3D<float> & rhs,const std::function<float (float,float)> & map_function)684 /* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::MapArray3D(
685 const Array3D<float>& lhs, const Array3D<float>& rhs,
686 const std::function<float(float, float)>& map_function) {
687 CHECK_EQ(lhs.n1(), rhs.n1());
688 CHECK_EQ(lhs.n2(), rhs.n2());
689 CHECK_EQ(lhs.n3(), rhs.n3());
690 int64_t n1 = lhs.n1();
691 int64_t n2 = rhs.n2();
692 int64_t n3 = rhs.n3();
693 auto result = std::make_unique<Array3D<float>>(n1, n2, n3);
694 for (int64_t i = 0; i < n1; ++i) {
695 for (int64_t j = 0; j < n2; ++j) {
696 for (int64_t k = 0; k < n3; ++k) {
697 (*result)(i, j, k) = map_function(lhs(i, j, k), rhs(i, j, k));
698 }
699 }
700 }
701 return result;
702 }
703
MapWithIndexArray2D(const Array2D<float> & matrix,const std::function<float (float,int64_t,int64_t)> & map_function)704 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
705 const Array2D<float>& matrix,
706 const std::function<float(float, int64_t, int64_t)>& map_function) {
707 int64_t rows = matrix.height();
708 int64_t cols = matrix.width();
709 auto result = std::make_unique<Array2D<float>>(rows, cols);
710 for (int64_t i = 0; i < rows; ++i) {
711 for (int64_t j = 0; j < cols; ++j) {
712 (*result)(i, j) = map_function(matrix(i, j), i, j);
713 }
714 }
715 return result;
716 }
717
718 } // namespace xla
719