/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#include "tensorflow/core/kernels/conv_grad_input_ops.h"

#include <utility>

#include "tensorflow/core/profiler/lib/scoped_annotation.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

// To be used inside depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropInputOp<CPUDevice, bfloat16>;
template struct LaunchConv2DBackpropInputOp<CPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropInputOp<CPUDevice, float>;
template struct LaunchConv2DBackpropInputOp<CPUDevice, double>;

// GPU definitions.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// The slow version (but compiles for GPU)

// A dummy type to group backward data autotune results together.
struct ConvBackwardDataAutotuneGroup {
  static string name() { return "ConvBwdData"; }
};

typedef AutotuneSingleton<ConvBackwardDataAutotuneGroup, ConvParameters,
                          AutotuneEntry<se::dnn::ConvOp>>
    AutotuneConvBwdData;
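// Caches the best backward-data algorithm found for each unique ConvParameters
// key (shapes, strides, dilations, padding, dtype, device), so the autotuning
// search is performed at most once per configuration per process.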

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Computes backprop input using Eigen::SpatialConvolutionBackwardInput on GPU
// for int32 inputs.
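// (int32 is not covered by the cuDNN-based path below, so this specialization
// routes through the generic Eigen-based implementation instead.)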
template <>
struct LaunchConv2DBackpropInputOp<GPUDevice, int32> {
  void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
                  const Tensor& out_backprop, const Tensor& filter,
                  int row_dilation, int col_dilation, int row_stride,
                  int col_stride, const Padding& padding,
                  const std::vector<int64_t>& explicit_paddings,
                  Tensor* in_backprop, TensorFormat data_format) {
    LaunchConv2DBackpropInputOpImpl<GPUDevice, int32> launcher;
    launcher(ctx, use_cudnn, cudnn_use_autotune, out_backprop, filter,
             row_dilation, col_dilation, row_stride, col_stride, padding,
             explicit_paddings, in_backprop, data_format);
  }
};
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename T>
void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
    OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
    const Tensor& out_backprop, const Tensor& filter, int row_dilation,
    int col_dilation, int row_stride, int col_stride, const Padding& padding,
    const std::vector<int64_t>& explicit_paddings, Tensor* in_backprop,
    TensorFormat data_format) {
  using se::dnn::AlgorithmConfig;
  using se::dnn::AlgorithmDesc;
  using se::dnn::ProfileResult;

  std::vector<int32> strides(4, 1);
  std::vector<int32> dilations(4, 1);
  auto input_h = GetTensorDimIndex(data_format, 'H');
  auto input_w = GetTensorDimIndex(data_format, 'W');
  strides[input_h] = row_stride;
  strides[input_w] = col_stride;
  dilations[input_h] = row_dilation;
  dilations[input_w] = col_dilation;
  TensorShape input_shape = in_backprop->shape();

  const TensorShape& filter_shape = filter.shape();
  ConvBackpropDimensions dims;
  OP_REQUIRES_OK(
      ctx, ConvBackpropComputeDimensionsV2(
               "Conv2DSlowBackpropInput", /*num_spatial_dims=*/2, input_shape,
               filter_shape, out_backprop.shape(), dilations, strides, padding,
               explicit_paddings, data_format, &dims));

  int64_t padding_top = -1, padding_bottom = -1;
  int64_t padding_left = -1, padding_right = -1;
  if (padding == EXPLICIT) {
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', &padding_top,
                             &padding_bottom);
    GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', &padding_left,
                             &padding_right);
  }
  int64_t expected_out_rows, expected_out_cols;
  // The call is guaranteed to succeed because we checked earlier that the
  // output and padding were valid.
  TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
      dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
      row_dilation, row_stride, padding, &expected_out_rows, &padding_top,
      &padding_bottom));
  DCHECK_EQ(dims.spatial_dims[0].output_size, expected_out_rows);
  TF_CHECK_OK(GetWindowedOutputSizeVerboseV2(
      dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
      col_dilation, col_stride, padding, &expected_out_cols, &padding_left,
      &padding_right));
  DCHECK_EQ(dims.spatial_dims[1].output_size, expected_out_cols);

  auto* stream = ctx->op_device_context()->stream();
  OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));

  if (!use_cudnn) {
    ctx->SetStatus(errors::Unimplemented(
        "Conv2DBackpropInput for GPU is not currently supported "
        "without cudnn"));
    return;
  }

  // If the filter in-depth (filter_shape.dim_size(2)) is 1 and smaller than
  // the input depth, this is a depthwise convolution. More generally, if the
  // filter in-depth divides the input depth but is smaller than it, this is a
  // grouped convolution.
  bool is_grouped_convolution = filter_shape.dim_size(2) != dims.in_depth;
  if (dims.spatial_dims[0].filter_size == 1 &&
      dims.spatial_dims[1].filter_size == 1 && !is_grouped_convolution &&
      dims.spatial_dims[0].stride == 1 && dims.spatial_dims[1].stride == 1 &&
      data_format == FORMAT_NHWC && (padding == VALID || padding == SAME)) {
    // 1x1 filter, so call cublas directly.
    const uint64 m = dims.batch_size * dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size;
    const uint64 k = dims.out_depth;
    const uint64 n = dims.in_depth;

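    // With a 1x1 filter and unit strides, backprop-to-input is a single matrix
    // product: in_backprop[m x n] = out_backprop[m x k] * filter[n x k]^T,
    // where m = batch * rows * cols, k = out_depth, n = in_depth. The GEMM
    // below expresses this row-major product in cuBLAS' column-major
    // convention by swapping the operand order.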
    auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                in_backprop->template flat<T>().size());

    auto transpose = se::blas::Transpose::kTranspose;
    auto no_transpose = se::blas::Transpose::kNoTranspose;

    OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(
                            transpose, no_transpose, n, m, k, b_ptr, k, a_ptr,
                            k, &c_ptr, n, se::blas::kDefaultComputePrecision));
    return;
  } else if (dims.spatial_dims[0].filter_size ==
                 dims.spatial_dims[0].input_size &&
             dims.spatial_dims[1].filter_size ==
                 dims.spatial_dims[1].input_size &&
             !is_grouped_convolution && padding == VALID &&
             data_format == FORMAT_NHWC) {
    // The input data and filter have the same height/width, and we are not
    // using grouped convolution, so call cublas directly.
    const uint64 m = dims.batch_size;
    const uint64 k = dims.out_depth;
    const uint64 n = dims.spatial_dims[0].input_size *
                     dims.spatial_dims[1].input_size * dims.in_depth;

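    // The filter covers the entire input with VALID padding, so each output
    // spatial position is 1x1 and the backward pass is again a single GEMM:
    // in_backprop[m x n] = out_backprop[m x k] * filter[n x k]^T, where
    // m = batch, k = out_depth, n = rows * cols * in_depth.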
    auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
                                out_backprop.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
                                filter.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
                                in_backprop->template flat<T>().size());

    auto transpose = se::blas::Transpose::kTranspose;
    auto no_transpose = se::blas::Transpose::kNoTranspose;

    OP_REQUIRES_OK(ctx, stream->ThenBlasGemm(
                            transpose, no_transpose, n, m, k, b_ptr, k, a_ptr,
                            k, &c_ptr, n, se::blas::kDefaultComputePrecision));
    return;
  }

  const int64_t common_padding_rows = std::min(padding_top, padding_bottom);
  const int64_t common_padding_cols = std::min(padding_left, padding_right);
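  // cuDNN takes a single, symmetric padding value per spatial dimension. For
  // asymmetric explicit padding, compute the gradient for an enlarged
  // "compatible" input that absorbs the padding difference, then slice off the
  // extra rows/columns afterwards (see the PadInput call with negative offsets
  // below).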
  TensorShape compatible_input_shape;
  if (padding_top != padding_bottom || padding_left != padding_right) {
    // Pad the input in the same way we did during the forward pass, so that
    // cuDNN or MIOpen receives the same input during the backward pass as it
    // did during the forward pass.
    const int64_t padding_rows_diff = std::abs(padding_bottom - padding_top);
    const int64_t padding_cols_diff = std::abs(padding_right - padding_left);
    const int64_t new_in_rows =
        dims.spatial_dims[0].input_size + padding_rows_diff;
    const int64_t new_in_cols =
        dims.spatial_dims[1].input_size + padding_cols_diff;
    compatible_input_shape = ShapeFromFormat(
        data_format, dims.batch_size, new_in_rows, new_in_cols, dims.in_depth);
  } else {
    compatible_input_shape = input_shape;
  }

  CHECK(common_padding_rows >= 0 && common_padding_cols >= 0)  // Crash OK
      << "Negative row or col paddings: (" << common_padding_rows << ", "
      << common_padding_cols << ")";

  // Tensor Cores in NVIDIA Volta+ GPUs support efficient convolution with fp16
  // in the NHWC data layout. In all other configurations it is more efficient
  // to run the computation in the NCHW data format.
  const bool compute_in_nhwc = DataTypeToEnum<T>::value == DT_HALF &&
                               stream->GetCudaComputeCapability().IsAtLeast(
                                   se::CudaComputeCapability::VOLTA);

  // We only do a one-directional conversion: NHWC->NCHW. We never convert in
  // the other direction. The Grappler layout optimizer selects the preferred
  // layout and adds the necessary annotations to the graph.
  const TensorFormat compute_data_format =
      (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
                                                      : FORMAT_NCHW;

  VLOG(3) << "Compute Conv2DBackpropInput with cuDNN:"
          << " data_format=" << ToString(data_format)
          << " compute_data_format=" << ToString(compute_data_format);

  constexpr auto kComputeInNHWC =
      std::make_tuple(se::dnn::DataLayout::kBatchYXDepth,
                      se::dnn::FilterLayout::kOutputYXInput);
  constexpr auto kComputeInNCHW =
      std::make_tuple(se::dnn::DataLayout::kBatchDepthYX,
                      se::dnn::FilterLayout::kOutputInputYX);
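  // kBatchYXDepth/kOutputYXInput correspond to NHWC data with OHWI filters;
  // kBatchDepthYX/kOutputInputYX correspond to NCHW data with OIHW filters.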

  se::dnn::DataLayout compute_data_layout;
  se::dnn::FilterLayout filter_layout;

  std::tie(compute_data_layout, filter_layout) =
      compute_data_format == FORMAT_NHWC ? kComputeInNHWC : kComputeInNCHW;

  se::dnn::BatchDescriptor input_desc;
  input_desc.set_count(dims.batch_size)
      .set_height(GetTensorDim(compatible_input_shape, data_format, 'H'))
      .set_width(GetTensorDim(compatible_input_shape, data_format, 'W'))
      .set_feature_map_count(dims.in_depth)
      .set_layout(compute_data_layout);
  se::dnn::BatchDescriptor output_desc;
  output_desc.set_count(dims.batch_size)
      .set_height(dims.spatial_dims[0].output_size)
      .set_width(dims.spatial_dims[1].output_size)
      .set_feature_map_count(dims.out_depth)
      .set_layout(compute_data_layout);
  se::dnn::FilterDescriptor filter_desc;
  filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
      .set_input_filter_width(dims.spatial_dims[1].filter_size)
      .set_input_feature_map_count(filter_shape.dim_size(2))
      .set_output_feature_map_count(filter_shape.dim_size(3))
      .set_layout(filter_layout);
  se::dnn::ConvolutionDescriptor conv_desc;
  conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
      .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
      .set_vertical_filter_stride(dims.spatial_dims[0].stride)
      .set_horizontal_filter_stride(dims.spatial_dims[1].stride)
      .set_zero_padding_height(common_padding_rows)
      .set_zero_padding_width(common_padding_cols)
      .set_group_count(dims.in_depth / filter_shape.dim_size(2));
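  // A filter in-depth smaller than the input depth maps onto cuDNN's grouped
  // convolution support: group_count = in_depth / filter_in_depth, with
  // depthwise convolution as the special case filter_in_depth == 1.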

  // Tensorflow filter format: HWIO
  // cuDNN filter formats: (data format) -> (filter format)
  //   (1) NCHW -> OIHW
  //   (2) NHWC -> OHWI

  Tensor transformed_filter;
  const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status {
    VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
            << " to " << ToString(dst_format);

    TensorShape dst_shape =
        dst_format == FORMAT_OIHW
            ? TensorShape({filter.dim_size(3), filter.dim_size(2),
                           filter.dim_size(0), filter.dim_size(1)})
            : TensorShape({filter.dim_size(3), filter.dim_size(0),
                           filter.dim_size(1), filter.dim_size(2)});

    TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
                                          &transformed_filter));
    functor::TransformFilter<GPUDevice, T, int, 4>()(
        ctx->eigen_device<GPUDevice>(), dst_format,
        To32Bit(filter.tensor<T, 4>()),
        To32Bit(transformed_filter.tensor<T, 4>()));

    return OkStatus();
  };

  if (compute_data_format == FORMAT_NCHW) {
    OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OIHW));
  } else if (compute_data_format == FORMAT_NHWC) {
    OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OHWI));
  } else {
    ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ",
                                           ToString(compute_data_format)));
    return;
  }

  Tensor transformed_out_backprop;
  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the `out_backprop` tensor from NHWC to NCHW.";
    TensorShape compute_shape = ShapeFromFormat(
        compute_data_format, dims.batch_size, dims.spatial_dims[0].output_size,
        dims.spatial_dims[1].output_size, dims.out_depth);
    if (dims.out_depth > 1) {
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DataTypeToEnum<T>::value, compute_shape,
                                        &transformed_out_backprop));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          ctx->eigen_device<GPUDevice>(), out_backprop.tensor<T, 4>(),
          transformed_out_backprop.tensor<T, 4>());
    } else {
      // If depth <= 1, then just reshape.
      CHECK(transformed_out_backprop.CopyFrom(out_backprop, compute_shape));
    }
  } else {
    transformed_out_backprop = out_backprop;
  }

  Tensor pre_transformed_in_backprop;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(
               DataTypeToEnum<T>::value,
               ShapeFromFormat(
                   compute_data_format,
                   GetTensorDim(compatible_input_shape, data_format, 'N'),
                   GetTensorDim(compatible_input_shape, data_format, 'H'),
                   GetTensorDim(compatible_input_shape, data_format, 'W'),
                   GetTensorDim(compatible_input_shape, data_format, 'C')),
               &pre_transformed_in_backprop));

  auto out_backprop_ptr =
      AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                     transformed_out_backprop.template flat<T>().size());
  auto filter_ptr =
      AsDeviceMemory(transformed_filter.template flat<T>().data(),
                     transformed_filter.template flat<T>().size());
  auto in_backprop_ptr =
      AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
                     pre_transformed_in_backprop.template flat<T>().size());

  static int64_t ConvolveBackwardDataScratchSize =
      GetDnnWorkspaceLimitOrDefault();
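  // Upper bound on the cuDNN/MIOpen scratch ("workspace") allocation below;
  // the default limit can typically be overridden with the
  // TF_CUDNN_WORKSPACE_LIMIT_IN_MB environment variable.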

  int device_id = stream->parent()->device_ordinal();
  DataType dtype = out_backprop.dtype();
  ConvParameters conv_parameters = {
      dims.batch_size,                     // batch
      dims.in_depth,                       // in_depths
      {{input_desc.height(),               // in_rows
        input_desc.width()}},              // in_cols
      compute_data_format,                 // compute_data_format
      dims.out_depth,                      // out_depths
      {{dims.spatial_dims[0].filter_size,  // filter_rows
        dims.spatial_dims[1].filter_size,  // filter_cols
        filter_shape.dim_size(2)}},        // filter_depths
      {{dims.spatial_dims[0].dilation,     // dilation_rows
        dims.spatial_dims[1].dilation}},   // dilation_cols
      {{dims.spatial_dims[0].stride,       // stride_rows
        dims.spatial_dims[1].stride}},     // stride_cols
      {{common_padding_rows,               // padding_rows
        common_padding_cols}},             // padding_cols
      dtype,                               // tensor data type
      device_id,                           // device_id
      conv_desc.group_count()              // group_count
  };

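  // AutotuneUnfusedConv looks up (or, when cudnn_use_autotune is set, profiles
  // and caches) the backward-data algorithm to use for this ConvParameters
  // key; LaunchAutotunedConv below then executes the selected plan.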
  auto entry_or = AutotuneUnfusedConv(
      cudnn_use_autotune, AutotuneConvBwdData::GetInstance(), conv_parameters,
      ctx, se::dnn::ConvolutionKind::BACKWARD_DATA, input_desc, in_backprop_ptr,
      filter_desc, filter_ptr, conv_desc, output_desc, out_backprop_ptr,
      ConvolveBackwardDataScratchSize);
  OP_REQUIRES_OK(ctx, entry_or.status());
  auto autotune_entry = std::move(entry_or).value();

  DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx);
  Status cudnn_launch_status =
      LaunchAutotunedConv(autotune_entry, &scratch_allocator,
                          se::dnn::ConvolutionKind::BACKWARD_DATA, stream,
                          input_desc, in_backprop_ptr, filter_desc, filter_ptr,
                          conv_desc, output_desc, out_backprop_ptr);
  if (!cudnn_launch_status.ok()) {
    ctx->SetStatus(cudnn_launch_status);
    return;
  }

  if (padding_top != padding_bottom || padding_left != padding_right) {
    Tensor in_backprop_remove_padding;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_temp(
                 DataTypeToEnum<T>::value,
                 ShapeFromFormat(compute_data_format,
                                 GetTensorDim(input_shape, data_format, 'N'),
                                 GetTensorDim(input_shape, data_format, 'H'),
                                 GetTensorDim(input_shape, data_format, 'W'),
                                 GetTensorDim(input_shape, data_format, 'C')),
                 &in_backprop_remove_padding));

    // Remove the padding that was added to the input shape above.
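    // Negative padding amounts passed to PadInput below slice those rows and
    // columns off rather than pad, restoring the caller-visible input shape.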
    const int64_t input_pad_top = padding_top - common_padding_rows;
    const int64_t input_pad_bottom = padding_bottom - common_padding_rows;
    const int64_t input_pad_left = padding_left - common_padding_cols;
    const int64_t input_pad_right = padding_right - common_padding_cols;
    functor::PadInput<GPUDevice, T, int, 4>()(
        ctx->template eigen_device<GPUDevice>(),
        To32Bit(const_cast<const Tensor&>(pre_transformed_in_backprop)
                    .tensor<T, 4>()),
        {{static_cast<int>(-input_pad_top), static_cast<int>(-input_pad_left)}},
        {{static_cast<int>(-input_pad_bottom),
          static_cast<int>(-input_pad_right)}},
        To32Bit(in_backprop_remove_padding.tensor<T, 4>()), compute_data_format,
        T{});

    pre_transformed_in_backprop = in_backprop_remove_padding;
  }

  if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) {
    VLOG(4) << "Convert the output tensor back from NCHW to NHWC.";
    auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; };
    functor::NCHWToNHWC<GPUDevice, T, 4>()(
        ctx->eigen_device<GPUDevice>(),
        toConstTensor(pre_transformed_in_backprop).template tensor<T, 4>(),
        in_backprop->tensor<T, 4>());
  } else {
    *in_backprop = pre_transformed_in_backprop;
  }
}

// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T)                                             \
  template <>                                                           \
  void TransformFilter<GPUDevice, T, int, 4>::operator()(               \
      const GPUDevice& d, FilterTensorFormat dst_filter_format,         \
      typename TTypes<T, 4, int>::ConstTensor in,                       \
      typename TTypes<T, 4, int>::Tensor out);                          \
  extern template struct TransformFilter<GPUDevice, T, int, 4>;         \
  template <>                                                           \
  void PadInput<GPUDevice, T, int, 4>::operator()(                      \
      const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in,   \
      const std::array<int, 2>& padding_left,                           \
      const std::array<int, 2>& padding_right,                          \
      typename TTypes<T, 4, int>::Tensor out, TensorFormat data_format, \
      const T& padding_value);                                          \
  extern template struct PadInput<GPUDevice, T, int, 4>;

DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(double);
#undef DECLARE_GPU_SPEC

template <>
void SpatialConvolutionBackwardInputFunc<GPUDevice, int32>::operator()(
    const GPUDevice&, typename TTypes<int32, 4>::Tensor,
    typename TTypes<int32, 4>::ConstTensor,
    typename TTypes<int32, 4>::ConstTensor, Eigen::DenseIndex,
    Eigen::DenseIndex, Eigen::DenseIndex, Eigen::DenseIndex);
extern template struct SpatialConvolutionBackwardInputFunc<GPUDevice, int32>;

template <>
void SpatialConvolutionBackwardInputWithExplicitPaddingFunc<
    GPUDevice, int32>::operator()(const GPUDevice&,
                                  typename TTypes<int32, 4>::Tensor,
                                  typename TTypes<int32, 4>::ConstTensor,
                                  typename TTypes<int32, 4>::ConstTensor,
                                  Eigen::DenseIndex, Eigen::DenseIndex,
                                  Eigen::DenseIndex, Eigen::DenseIndex,
                                  Eigen::DenseIndex, Eigen::DenseIndex,
                                  Eigen::DenseIndex, Eigen::DenseIndex);
extern template struct SpatialConvolutionBackwardInputWithExplicitPaddingFunc<
    GPUDevice, int32>;

}  // namespace functor

REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<double>("T")
                            .HostMemory("input_sizes"),
                        Conv2DBackpropInputOp<GPUDevice, double>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .HostMemory("input_sizes"),
                        Conv2DBackpropInputOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .HostMemory("input_sizes"),
                        Conv2DBackpropInputOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropInput")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input_sizes"),
                        Conv2DBackpropInputOp<GPUDevice, int32>);

// To be used inside depthwise_conv_grad_op.cc.
// TODO(reedwm): Move this and the definition to depthwise_conv_grad_op.cc.
template struct LaunchConv2DBackpropInputOp<GPUDevice, float>;
template struct LaunchConv2DBackpropInputOp<GPUDevice, Eigen::half>;
template struct LaunchConv2DBackpropInputOp<GPUDevice, double>;

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow