/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_
#define TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include <algorithm>
#include <limits>
#include <vector>

#include "absl/base/dynamic_annotations.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/kernels/fill_functor.h"
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
#include "tensorflow/core/kernels/xsmm_conv2d.h"
#endif
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// Returns in 'im_data' (assumed to be zero-initialized) the image patch in
// storage order (height, width, depth), constructed from patches in
// 'col_data', which is required to be in storage order
// (out_height * out_width, filter_height, filter_width, in_depth).
// Implementation by Yangqing Jia (jiayq).
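// The number of patch positions per spatial dimension equals the forward
// convolution's output size, (size + pad_before + pad_after - filter_size) /
// stride + 1; patch elements that fall into the padded border are skipped.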
template <typename T>
void Col2im(const T* col_data, const int depth, const int height,
            const int width, const int filter_h, const int filter_w,
            const int pad_t, const int pad_l, const int pad_b, const int pad_r,
            const int stride_h, const int stride_w, T* __restrict im_data) {
  int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
  int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
  int h_pad = -pad_t;
  for (int h = 0; h < height_col; ++h) {
    int w_pad = -pad_l;
    for (int w = 0; w < width_col; ++w) {
      T* im_patch_data = im_data + (h_pad * width + w_pad) * depth;
      for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
        for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
          if (ih >= 0 && ih < height && iw >= 0 && iw < width) {
            for (int i = 0; i < depth; ++i) {
              im_patch_data[i] += col_data[i];
            }
          }
          im_patch_data += depth;
          col_data += depth;
        }
        // Advance past the rest of this image row to reach the next patch row.
        im_patch_data += depth * (width - filter_w);
      }
      w_pad += stride_w;
    }
    h_pad += stride_h;
  }
}

// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on CPU
// and GPU (for int32 only).
template <typename Device, typename T>
struct LaunchConv2DBackpropInputOpImpl {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& out_backprop, const Tensor& filter,
                  int row_dilation, int col_dilation, int row_stride,
                  int col_stride, const Padding& padding,
                  const std::vector<int64_t>& explicit_paddings,
                  Tensor* in_backprop, TensorFormat data_format) {
    std::vector<int32> strides(4, 1);
    std::vector<int32> dilations(4, 1);

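    // Only the spatial ('H'/'W') entries carry non-unit strides and dilations;
    // the batch and depth entries stay at 1.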
    auto input_h = GetTensorDimIndex(data_format, 'H');
    auto input_w = GetTensorDimIndex(data_format, 'W');
    strides[input_h] = row_stride;
    strides[input_w] = col_stride;
    dilations[input_h] = row_dilation;
    dilations[input_w] = col_dilation;

    const TensorShape& input_shape = in_backprop->shape();
    const TensorShape& filter_shape = filter.shape();

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(
        ctx, ConvBackpropComputeDimensionsV2(
                 "Conv2DBackpropInput", /*num_spatial_dims=*/2, input_shape,
                 filter_shape, out_backprop.shape(), dilations, strides,
                 padding, explicit_paddings, data_format, &dims));

    int64_t padding_top = -1, padding_bottom = -1;
    int64_t padding_left = -1, padding_right = -1;
    if (padding == EXPLICIT) {
      GetExplicitPaddingForDim(explicit_paddings, data_format, 'H',
                               &padding_top, &padding_bottom);
      GetExplicitPaddingForDim(explicit_paddings, data_format, 'W',
                               &padding_left, &padding_right);
    }

    int64_t expected_out_rows, expected_out_cols;
    // These calls are guaranteed to succeed because the output shape and
    // padding were already validated above.
    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
        dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
        row_dilation, row_stride, padding, &expected_out_rows, &padding_top,
        &padding_bottom));
    DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows);

    TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
        dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
        col_dilation, col_stride, padding, &expected_out_cols, &padding_left,
        &padding_right));
    DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols);

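    // The GPU flavor of this Eigen path handles int32 only, so reject tensors
    // whose padded dimensions or total (padded) element count overflow int32.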
    if (std::is_same<Device, GPUDevice>::value) {
      int64_t size = 1;
#define REQUIRES_32BIT(x)                                                   \
  size *= x;                                                                \
  OP_REQUIRES(ctx,                                                          \
              FastBoundsCheck(x, std::numeric_limits<int32>::max()) &&      \
                  FastBoundsCheck(size, std::numeric_limits<int32>::max()), \
              errors::InvalidArgument("Tensor too large"))

      REQUIRES_32BIT(in_backprop->dim_size(0));
      REQUIRES_32BIT(in_backprop->dim_size(1) + padding_top + padding_bottom);
      REQUIRES_32BIT(in_backprop->dim_size(2) + padding_left + padding_right);
      REQUIRES_32BIT(in_backprop->dim_size(3));
#undef REQUIRES_32BIT
    }

    auto in_backprop_t = in_backprop->tensor<T, 4>();
    auto out_backprop_t = out_backprop.tensor<T, 4>();
    auto filter_t = filter.tensor<T, 4>();

    // WARNING: Need to swap row/col, padding_top/padding_left, and
    // padding_bottom/padding_right when calling Eigen. Eigen expects tensors
    // in NWHC format, but TensorFlow uses NHWC.

    if (padding != EXPLICIT) {
      // If padding was not explicitly defined, Eigen spatial convolution
      // backward input will infer correct forward paddings from input tensors.
      functor::SpatialConvolutionBackwardInputFunc<Device, T>()(
          ctx->eigen_device<Device>(), in_backprop_t, filter_t, out_backprop_t,
          col_stride, row_stride, col_dilation, row_dilation);
    } else {
      functor::SpatialConvolutionBackwardInputWithExplicitPaddingFunc<Device,
                                                                      T>()(
          ctx->eigen_device<Device>(), in_backprop_t, filter_t, out_backprop_t,
          in_backprop_t.dimension(2) + (padding_left + padding_right),
          in_backprop_t.dimension(1) + (padding_top + padding_bottom),
          col_stride, row_stride, col_dilation, row_dilation, padding_top,
          padding_left);
    }
  }
};

// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on CPU.
template <typename T>
struct LaunchConv2DBackpropInputOp<CPUDevice, T> {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& out_backprop, const Tensor& filter,
                  int row_dilation, int col_dilation, int row_stride,
                  int col_stride, const Padding& padding,
                  const std::vector<int64_t>& explicit_paddings,
                  Tensor* in_backprop, TensorFormat data_format) {
    LaunchConv2DBackpropInputOpImpl<CPUDevice, T> launcher;
    launcher(ctx, use_cudnn, cudnn_use_autotune, out_backprop, filter,
             row_dilation, col_dilation, row_stride, col_stride, padding,
             explicit_paddings, in_backprop, data_format);
  }
};

#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
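// Generic fallback that always reports failure, so callers fall through to
// the default (non-LIBXSMM) path; specialized for float on CPU below.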
template <typename Device, class T>
struct LaunchXsmmBackwardInputConvolution {
  bool operator()(OpKernelContext* context, const Device& d,
                  typename TTypes<T, 4>::Tensor input_backward,
                  typename TTypes<T, 4>::ConstTensor kernel,
                  typename TTypes<T, 4>::ConstTensor output_backward,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, int pad_h, int pad_w,
                  TensorFormat data_format) const {
    return false;
  }
};

template <>
struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
  bool operator()(OpKernelContext* context, const CPUDevice& d,
                  typename TTypes<float, 4>::Tensor input_backward,
                  typename TTypes<float, 4>::ConstTensor kernel,
                  typename TTypes<float, 4>::ConstTensor output_backward,
                  int input_rows, int input_cols, int row_stride,
                  int col_stride, int pad_h, int pad_w,
                  TensorFormat data_format) const {
    auto batch = input_backward.dimension(0);
    auto in_depth = input_backward.dimension(3);
    auto out_depth = output_backward.dimension(3);
    auto filter_rows = kernel.dimension(0);
    auto filter_cols = kernel.dimension(1);
    auto num_threads =
        context->device()->tensorflow_cpu_worker_threads()->num_threads;
    // See libxsmm_dnn.h for this struct definition.
    libxsmm_dnn_conv_desc desc;
    desc.N = batch;
    desc.C = in_depth;
    desc.H = input_rows;
    desc.W = input_cols;
    desc.K = out_depth;
    desc.R = filter_rows;
    desc.S = filter_cols;
    desc.u = row_stride;
    desc.v = col_stride;
    desc.pad_h = pad_h;
    desc.pad_w = pad_w;
    desc.pad_h_in = 0;
    desc.pad_w_in = 0;
    desc.pad_h_out = 0;
    desc.pad_w_out = 0;
    desc.threads = num_threads;
    desc.algo = LIBXSMM_DNN_CONV_ALGO_DIRECT;
    desc.buffer_format = LIBXSMM_DNN_TENSOR_FORMAT_NHWC;
    desc.filter_format =
        LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM;  // LIBXSMM_DNN_TENSOR_FORMAT_RSCK;
    desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
    desc.options = LIBXSMM_DNN_CONV_OPTION_OVERWRITE;
    desc.datatype_out = LIBXSMM_DNN_DATATYPE_F32;
    desc.datatype_in = LIBXSMM_DNN_DATATYPE_F32;
    auto input_ptr = input_backward.data();
    auto filter_ptr = kernel.data();
    auto output_ptr = output_backward.data();

    bool success = functor::XsmmBkwInputConv2D<CPUDevice, float>()(
        context, desc, input_ptr, filter_ptr, output_ptr);
    return success;
  }
};
#endif

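// Computes the gradient of the im2col patches as a single matmul:
//   im2col_buf[ois x fts] = out_data[ois x dod] * filter_data[fts x dod]^T
// where ois = output_image_size, fts = filter_total_size, dod = dims_out_depth.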
template <typename T>
struct Conv2DCustomBackpropInputMatMulFunctor {
  using MatrixMap = Eigen::Map<
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
  using ConstMatrixMap = Eigen::Map<
      const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;

  void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data,
                  const int filter_total_size, const int output_image_size,
                  const int dims_out_depth, T* im2col_buf) {
    // Compute gradient into 'im2col_buf'.
    MatrixMap C(im2col_buf, output_image_size, filter_total_size);

    ConstMatrixMap A(out_data, output_image_size, dims_out_depth);
    ConstMatrixMap B(filter_data, filter_total_size, dims_out_depth);

    C.noalias() = A * B.transpose();
  }
};

#if defined(TENSORFLOW_USE_MKLDNN_CONTRACTION_KERNEL)
template <>
struct Conv2DCustomBackpropInputMatMulFunctor<float> {
  using T = float;

  void operator()(OpKernelContext* ctx, const T* out_data, const T* filter_data,
                  const int filter_total_size, const int output_image_size,
                  const int dims_out_depth, T* im2col_buf) {
    // Inputs are in RowMajor order.
    //   im2col      = out_data    * filter_data^T
    //   [ois x fts] = [ois x dod] * [fts x dod]^T
    //
    // Dimension names:
    //   out_image_size    -> ois
    //   filter_total_size -> fts
    //   dims_out_depth    -> dod

    const int m = output_image_size;
    const int n = filter_total_size;
    const int k = dims_out_depth;  // contraction dim

    const char transposeA = 'N';  // sgemm(A) == out_data
    const char transposeB = 'T';  // sgemm(B) == filter_data

    const int ldA = dims_out_depth;
    const int ldB = dims_out_depth;
    const int ldC = filter_total_size;

    const float alpha = 1.0;
    const float beta = 0.0;

    // dnnl_sgemm code can't be instrumented with msan.
    ANNOTATE_MEMORY_IS_INITIALIZED(
        im2col_buf, filter_total_size * output_image_size * sizeof(T));

    dnnl_status_t st =
        dnnl_sgemm(transposeA, transposeB, m, n, k, alpha, out_data, ldA,
                   filter_data, ldB, beta, im2col_buf, ldC);

    OP_REQUIRES(
        ctx, st == 0,
        errors::Internal("Failed to call dnnl_sgemm. Error code: ", st));
  }
};
#endif

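// Conv2D input-gradient kernel: validates attributes and input shapes, then
// dispatches to the LaunchConv2DBackpropInputOp functor for the device.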
template <typename Device, class T>
class Conv2DBackpropInputOp : public OpKernel {
 public:
  explicit Conv2DBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    int stride_h = GetTensorDim(strides_, data_format_, 'H');
    int stride_w = GetTensorDim(strides_, data_format_, 'W');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));

    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
    int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
    int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
    int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
    OP_REQUIRES(
        context, (dilation_n == 1 && dilation_c == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "dilations in the batch and depth dimensions."));
    OP_REQUIRES(
        context, dilation_h > 0 && dilation_w > 0,
        errors::InvalidArgument("Dilated rates should be larger than 0."));

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));

    OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
    cudnn_use_autotune_ = CudnnUseAutotune();

    if (std::is_same<Device, CPUDevice>::value ||
        std::is_same<T, int32>::value) {
      OP_REQUIRES(
          context, data_format_ == FORMAT_NHWC,
          errors::InvalidArgument("Conv2DBackpropInputOp [CPU or GPU(int32)] "
                                  "only supports NHWC data format."));

      // TODO(yangzihao): Add a CPU implementation for dilated convolution.
      OP_REQUIRES(
          context, (dilation_h == 1 && dilation_w == 1),
          errors::InvalidArgument(
              "Conv2DBackpropInputOp [CPU or GPU(int32)] does not yet support "
              "dilation rates larger than 1."));
    }
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_sizes = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);

    OP_REQUIRES(
        context, out_backprop.dims() == 4,
        errors::InvalidArgument("out_backprop must be 4-dimensional, got: ",
                                out_backprop.dims()));

    TensorShape input_shape;
    OP_REQUIRES_OK(context,
                   Conv2DBackpropComputeInputShape(input_sizes, filter.shape(),
                                                   out_backprop.shape(),
                                                   data_format_, &input_shape));

    Tensor* in_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    // If there is nothing to compute, return.
    if (input_shape.num_elements() == 0) {
      return;
    }

    // If shapes are valid but `out_backprop` is empty, in_backprop should be
    // set to all zeros.  Otherwise, cudnn/dnnl fail with an empty input.
    if (out_backprop.NumElements() == 0) {
      functor::SetZeroFunctor<Device, T> set_zero;
      set_zero(context->eigen_device<Device>(),
               in_backprop->template flat<T>());
      return;
    }

    // For now we take the stride from the second and third dimensions only (we
    // do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
    const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
    const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');

    VLOG(2) << "Conv2DBackpropInput:"
            << " input: " << input_shape.DebugString()
            << " filter:" << filter.shape().DebugString()
            << " out_backprop: " << out_backprop.shape().DebugString()
            << " strides: [" << stride_rows << ", " << stride_cols << "]"
            << " dilations: [" << dilation_rows << ", " << dilation_cols << "]";

    LaunchConv2DBackpropInputOp<Device, T> launch;
    launch(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter,
           dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
           explicit_paddings_, in_backprop, data_format_);
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  TensorFormat data_format_;
  Padding padding_;
  std::vector<int64_t> explicit_paddings_;

  bool use_cudnn_ = false;
  bool cudnn_use_autotune_ = false;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DBackpropInputOp);
};

// Based on implementation written by Yangqing Jia (jiayq).
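// Computes the input gradient on CPU without cuDNN: for each image, the
// gradient of the im2col patches is formed with a matmul
// (out_backprop x filter^T) and then scattered back into in_backprop via
// Col2im. Requires NHWC data, unit dilations, and unit batch/depth strides.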
template <typename Device, class T>
class Conv2DCustomBackpropInputOp : public OpKernel {
 public:
  explicit Conv2DCustomBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, data_format_ == FORMAT_NHWC,
                errors::InvalidArgument(
                    "Conv2DCustomBackpropInputOp only supports NHWC."));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (strides_[0] == 1 && strides_[3] == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "strides in the batch and depth dimensions."));
    OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
                errors::InvalidArgument(
                    "Row and column strides should be larger than 0."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
    OP_REQUIRES(context, dilations_.size() == 4,
                errors::InvalidArgument("Sliding window dilations field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(
        context, (dilations_[0] == 1 && dilations_[3] == 1),
        errors::Unimplemented("Current implementation does not yet support "
                              "dilations in the batch and depth dimensions."));
    // TODO(yangzihao): Add a CPU implementation for dilated convolution.
    OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
                errors::InvalidArgument(
                    "Current libxsmm and customized CPU implementations do "
                    "not yet support dilation rates larger than 1."));
    OP_REQUIRES_OK(context,
                   context->GetAttr("explicit_paddings", &explicit_paddings_));
    OP_REQUIRES_OK(context, CheckValidPadding(padding_, explicit_paddings_,
                                              /*num_dims=*/4, data_format_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input_sizes = context->input(0);
    const Tensor& filter = context->input(1);
    const Tensor& out_backprop = context->input(2);
    OP_REQUIRES(
        context, out_backprop.dims() == 4,
        errors::InvalidArgument("out_backprop must be 4-dimensional, got: ",
                                out_backprop.dims()));

    TensorShape input_shape;
    OP_REQUIRES_OK(context,
                   Conv2DBackpropComputeInputShape(input_sizes, filter.shape(),
                                                   out_backprop.shape(),
                                                   data_format_, &input_shape));

    ConvBackpropDimensions dims;
    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensionsV2(
                       "Conv2DCustomBackpropInput", /*num_spatial_dims=*/2,
                       input_shape, filter.shape(), out_backprop.shape(),
                       /*dilations=*/{1, 1, 1, 1}, strides_, padding_,
                       explicit_paddings_, data_format_, &dims));

    OP_REQUIRES(context, dims.in_depth == filter.shape().dim_size(2),
                errors::InvalidArgument(
                    "Gradients for grouped convolutions are not "
                    "supported on CPU. Please file a feature request if you "
                    "run into this issue. Computed input depth ",
                    dims.in_depth, " doesn't match filter input depth ",
                    filter.shape().dim_size(2)));
    OP_REQUIRES(
        context, dims.out_depth == filter.shape().dim_size(3),
        errors::InvalidArgument("Computed output depth ", dims.out_depth,
                                " doesn't match filter output depth ",
                                filter.shape().dim_size(3)));

    Tensor* in_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input_shape, &in_backprop));

    // If there is nothing to compute, return.
    if (input_shape.num_elements() == 0) {
      return;
    }

    // If shapes are valid but `out_backprop` is empty, in_backprop should be
    // set to all zeros.  Otherwise, cudnn/dnnl fail with an empty input.
    if (out_backprop.NumElements() == 0) {
      functor::SetZeroFunctor<Device, T> set_zero;
      set_zero(context->eigen_device<Device>(),
               in_backprop->template flat<T>());
      return;
    }

// TODO(ezhulenev): Remove custom kernel and move XSMM support to
// LaunchConv2DBackpropInputOp functor.
#if defined TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS && \
    defined TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS
    int64 pad_top, pad_bottom;
    int64 pad_left, pad_right;
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));

    if (pad_left == pad_right && pad_top == pad_bottom) {
      if (LaunchXsmmBackwardInputConvolution<Device, T>()(
              context, context->eigen_device<Device>(),
              in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
              out_backprop.tensor<T, 4>(), dims.spatial_dims[0].input_size,
              dims.spatial_dims[1].input_size,
              static_cast<int>(dims.spatial_dims[0].stride),
              static_cast<int>(dims.spatial_dims[1].stride),
              static_cast<int>(pad_top), static_cast<int>(pad_left),
              data_format_)) {
        return;
      }
    }
#else
    int64_t pad_top, pad_bottom;
    int64_t pad_left, pad_right;
#endif
    if (padding_ == Padding::EXPLICIT) {
      pad_top = explicit_paddings_[2];
      pad_bottom = explicit_paddings_[3];
      pad_left = explicit_paddings_[4];
      pad_right = explicit_paddings_[5];
    }
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[0].stride, padding_,
            &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    OP_REQUIRES_OK(
        context,
        GetWindowedOutputSizeVerbose(
            dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
            dims.spatial_dims[1].stride, padding_,
            &dims.spatial_dims[1].output_size, &pad_left, &pad_right));

    // The total dimension size of each kernel.
    const int filter_total_size = dims.spatial_dims[0].filter_size *
                                  dims.spatial_dims[1].filter_size *
                                  dims.in_depth;
    // The output image size is the spatial size of the output.
    const int output_image_size =
        dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size;

    // TODO(andydavis) Get L2/L3 cache sizes from device.
    const size_t l2_cache_size = 256LL << 10;
    const size_t l3_cache_size = 30LL << 20;

    // Use L3 cache size as target working set size.
    const size_t target_working_set_size = l3_cache_size / sizeof(T);

    // Calculate size of matrices involved in MatMul: C = A x B.
    const size_t size_A = output_image_size * dims.out_depth;

    const size_t size_B = filter_total_size * dims.out_depth;

    const size_t size_C = output_image_size * filter_total_size;

    const size_t work_unit_size = size_A + size_B + size_C;

    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());

    // Calculate per-thread work unit size.
    const size_t thread_work_unit_size =
        work_unit_size / worker_threads.num_threads;

    // Set minimum per-thread work unit size to size of L2 cache.
    const size_t min_thread_work_unit_size = l2_cache_size / sizeof(T);

    // Use parallel tensor contractions if there is no batching, or if the
    // minimum per-thread work unit size threshold has been exceeded.
    // Otherwise, revert to multiple single-threaded matmul ops running in
    // parallel to keep all threads busy.
    // TODO(andydavis) Explore alternatives to branching the code in this way
    // (i.e. run multiple, parallel tensor contractions in another thread pool).
    const bool use_parallel_contraction =
        dims.batch_size == 1 ||
        thread_work_unit_size >= min_thread_work_unit_size;

    OP_REQUIRES(
        context, work_unit_size > 0,
        errors::InvalidArgument("input, filter_sizes and out_backprop tensors "
                                "must all have at least 1 element"));

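    // Number of images per shard: enough per-image work units to roughly fill
    // the L3-sized target working set, or a single image when using parallel
    // contractions.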
    const size_t shard_size =
        use_parallel_contraction
            ? 1
            : (target_working_set_size + work_unit_size - 1) / work_unit_size;

    Tensor col_buffer;
    OP_REQUIRES_OK(context,
                   context->allocate_temp(
                       DataTypeToEnum<T>::value,
                       TensorShape({static_cast<int64_t>(shard_size),
                                    static_cast<int64_t>(output_image_size),
                                    static_cast<int64_t>(filter_total_size)}),
                       &col_buffer));

    // The input offset corresponding to a single input image.
    const int input_offset = dims.spatial_dims[0].input_size *
                             dims.spatial_dims[1].input_size * dims.in_depth;
    // The output offset corresponding to a single output image.
    const int output_offset = dims.spatial_dims[0].output_size *
                              dims.spatial_dims[1].output_size * dims.out_depth;

    const T* filter_data = filter.template flat<T>().data();
    T* col_buffer_data = col_buffer.template flat<T>().data();
    const T* out_backprop_data = out_backprop.template flat<T>().data();

    auto in_backprop_flat = in_backprop->template flat<T>();
    T* input_backprop_data = in_backprop_flat.data();
    in_backprop_flat.device(context->eigen_device<Device>()) =
        in_backprop_flat.constant(T(0));

    if (use_parallel_contraction) {
      typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          TensorMap;
      typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
                               Eigen::Unaligned>
          ConstTensorMap;

      // Initialize contraction dims (we need to transpose 'B' below).
      Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
      contract_dims[0].first = 1;
      contract_dims[0].second = 1;
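      // Contract over the out_depth dimension:
      //   C[ois x fts] = A[ois x dod] * B[fts x dod]^T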

      for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
        // Compute gradient into col_buffer.
        TensorMap C(col_buffer_data, output_image_size, filter_total_size);

        ConstTensorMap A(out_backprop_data + output_offset * image_id,
                         output_image_size, dims.out_depth);
        ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);

        C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);

        Col2im<T>(
            col_buffer_data, dims.in_depth, dims.spatial_dims[0].input_size,
            dims.spatial_dims[1].input_size, dims.spatial_dims[0].filter_size,
            dims.spatial_dims[1].filter_size, pad_top, pad_left, pad_bottom,
            pad_right, dims.spatial_dims[0].stride, dims.spatial_dims[1].stride,
            input_backprop_data);

        input_backprop_data += input_offset;
      }
    } else {
      for (int image_id = 0; image_id < dims.batch_size;
           image_id += shard_size) {
        const int shard_limit =
            std::min(static_cast<int>(shard_size),
                     static_cast<int>(dims.batch_size) - image_id);

        auto shard = [&context, &dims, &pad_top, &pad_left, &pad_bottom,
                      &pad_right, &output_image_size, &filter_total_size,
                      &input_backprop_data, &col_buffer_data,
                      &out_backprop_data, &filter_data, &input_offset,
                      &output_offset, &size_C](int64_t start, int64_t limit) {
          for (int shard_id = start; shard_id < limit; ++shard_id) {
            T* im2col_buf = col_buffer_data + shard_id * size_C;
            T* input_data = input_backprop_data + shard_id * input_offset;
            const T* out_data = out_backprop_data + shard_id * output_offset;

            Conv2DCustomBackpropInputMatMulFunctor<T>()(
                context, out_data, filter_data, filter_total_size,
                output_image_size, dims.out_depth, im2col_buf);

            Col2im<T>(im2col_buf, dims.in_depth,
                      dims.spatial_dims[0].input_size,
                      dims.spatial_dims[1].input_size,
                      dims.spatial_dims[0].filter_size,
                      dims.spatial_dims[1].filter_size, pad_top, pad_left,
                      pad_bottom, pad_right, dims.spatial_dims[0].stride,
                      dims.spatial_dims[1].stride, input_data);
          }
        };
        Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
              work_unit_size, shard);

        input_backprop_data += input_offset * shard_limit;
        out_backprop_data += output_offset * shard_limit;
      }
    }
  }

 private:
  std::vector<int32> dilations_;
  std::vector<int32> strides_;
  Padding padding_;
  std::vector<int64_t> explicit_paddings_;
  TensorFormat data_format_;

  TF_DISALLOW_COPY_AND_ASSIGN(Conv2DCustomBackpropInputOp);
};

// TODO(ezhulenev): Add a cost model to switch between custom/Eigen ops.
#define DEFAULT_CONV_2D_BACKPROP_CPU_OP Conv2DCustomBackpropInputOp

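// Registers the custom (im2col-based) kernel as the default CPU kernel, plus
// explicitly labeled "custom" and "eigen_tensor" variants selectable via the
// kernel label.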
#define REGISTER_CONV_2D_BACKPROP_CPU_KERNELS(T)                             \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name("Conv2DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      DEFAULT_CONV_2D_BACKPROP_CPU_OP<CPUDevice, T>);                        \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")                        \
                              .Device(DEVICE_CPU)                            \
                              .Label("custom")                               \
                              .TypeConstraint<T>("T"),                       \
                          Conv2DCustomBackpropInputOp<CPUDevice, T>);        \
  REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")                        \
                              .Device(DEVICE_CPU)                            \
                              .Label("eigen_tensor")                         \
                              .TypeConstraint<T>("T"),                       \
                          Conv2DBackpropInputOp<CPUDevice, T>);

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_CONV_GRAD_INPUT_OPS_H_