xref: /aosp_15_r20/external/tensorflow/tensorflow/python/kernel_tests/nn_ops/depthwise_conv_op_base.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functional tests for depthwise convolutional operations."""
16
17import numpy as np
18
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import errors
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import gradient_checker
26from tensorflow.python.ops import nn_impl
27from tensorflow.python.ops import nn_ops
28import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
29from tensorflow.python.platform import test
30from tensorflow.python.platform import tf_logging
31
32
def _DepthwiseConv2dNumpyBasic(x1, x2, strides):
  """Reference depthwise_conv2d in Numpy, without padding or dilations.

  This allows us to test TensorFlow's depthwise_conv2d by comparing to the
  Numpy version. Only full filter windows are evaluated, i.e. the input is
  never padded here.

  Args:
    x1: The input Numpy array, in NHWC format.
    x2: The filter Numpy array, shaped [filter_rows, filter_cols, in_depth,
      depth_multiplier].
    strides: A Python list of 4 elements representing the strides. Only the
      two middle (row and column) entries are used.

  Returns:
    The depthwise conv2d output as a Numpy array, in NHWC format with
    in_depth * depth_multiplier channels.
  """
  batch, in_rows, in_cols, in_depth = x1.shape
  filter_rows, filter_cols, filter_depth, multiplier = x2.shape
  assert in_depth == filter_depth
  _, row_stride, col_stride, _ = strides

  # Number of positions at which the filter fits entirely inside the input.
  out_rows = (in_rows - filter_rows + row_stride) // row_stride
  out_cols = (in_cols - filter_cols + col_stride) // col_stride
  result = np.zeros([batch, out_rows, out_cols, in_depth * multiplier])

  for row in range(out_rows):
    top = row * row_stride
    for col in range(out_cols):
      left = col * col_stride
      # All channels of the input patch under the filter: (batch, fh, fw, c).
      window = x1[:, top:top + filter_rows, left:left + filter_cols, :]
      for channel in range(in_depth):
        # Broadcast one input channel against its `multiplier` filters.
        # products.shape: (batch, fh, fw, multiplier)
        products = window[..., channel, np.newaxis] * x2[:, :, channel, :]
        first = channel * multiplier
        # Each input channel owns `multiplier` consecutive output channels.
        result[:, row, col, first:first + multiplier] = np.sum(
            products, axis=(1, 2))
  return result
68
69
def _DepthwiseConv2dNumpy(x1, x2, strides, padding, data_format, dilations):
  """Compute depthwise_conv2d using Numpy.

  This allows us to test TensorFlow's depthwise_conv2d by comparing to the
  Numpy version.

  Unlike `_DepthwiseConv2dNumpyBasic`, this supports more advanced features
  like padding, dilations and the NCHW layout. The problem is reduced to an
  NHWC convolution over an already padded input with an already dilated
  filter, which is then delegated to `_DepthwiseConv2dNumpyBasic`.

  Args:
    x1: The input Numpy array, laid out according to `data_format`.
    x2: The filter Numpy array, shaped [filter_height, filter_width,
      in_channels, depth_multiplier].
    strides: A Python list of 4 elements representing the strides, ordered
      according to `data_format`.
    padding: The padding. "SAME", "VALID", or a list of 4 (before, after)
      pairs of explicit paddings, ordered according to `data_format`.
    data_format: "NHWC" or "NCHW".
    dilations: A list of 2 elements [height, width], representing the
      dilations. None or an empty list means no dilation.

  Returns:
    The depthwise conv2d as a Numpy array, laid out according to
    `data_format`.
  """
  is_nchw = data_format == "NCHW"
  if is_nchw:
    # Transpose arguments to NHWC format. NCHW -> NHWC is the permutation
    # (0, 2, 3, 1); strides and explicit paddings are per-dimension values and
    # therefore have to be reordered the same way.
    x1 = np.transpose(x1, (0, 2, 3, 1))
    strides = [strides[0], strides[2], strides[3], strides[1]]
    if not isinstance(padding, str):
      padding = [padding[0], padding[2], padding[3], padding[1]]
    # `dilations` only holds the two spatial entries, so there is nothing to
    # reorder for it.

  if dilations:
    # Dilate the filter so _DepthwiseConv2dNumpyBasic doesn't have to deal with
    # dilations: the original taps are spread out and the gaps stay zero.
    fh, fw, c, o = x2.shape
    new_fh = (fh - 1) * dilations[0] + 1
    new_fw = (fw - 1) * dilations[1] + 1
    new_x2 = np.zeros((new_fh, new_fw, c, o))
    for i in range(fh):
      for j in range(fw):
        new_x2[i * dilations[0], j * dilations[1], :, :] = x2[i, j, :, :]
    x2 = new_x2

  # Pad input so _DepthwiseConv2dNumpyBasic doesn't have to deal with padding.
  if padding == "SAME":

    def PaddingsForDim(input_dim, filter_dim, stride):
      """Computes the (before, after) "SAME" paddings for a single dimension."""
      if input_dim % stride == 0:
        total_padding = max(filter_dim - stride, 0)
      else:
        total_padding = max(filter_dim - (input_dim % stride), 0)
      # An odd amount of padding puts the extra element at the end.
      pad_before = total_padding // 2
      pad_after = total_padding - pad_before
      return pad_before, pad_after

    # x2 has already been dilated, so its shape is the effective filter size.
    padding = [(0, 0),
               PaddingsForDim(x1.shape[1], x2.shape[0], strides[1]),
               PaddingsForDim(x1.shape[2], x2.shape[1], strides[2]), (0, 0)]
  elif padding == "VALID":
    padding = [(0, 0)] * 4
  x1 = np.pad(x1, padding, "constant")

  y = _DepthwiseConv2dNumpyBasic(x1, x2, strides)

  if is_nchw:
    # Transpose back to NCHW format: NHWC -> NCHW is (0, 3, 1, 2).
    y = np.transpose(y, (0, 3, 1, 2))

  return y
136
137
def ConfigsToTest():
  """Returns convolution configs that use "SAME" or "VALID" padding.

  Returns:
    List of tuples (input_size, filter_size, out_size, stride, padding,
    dilations), the depthwise convolution parameters.
  """
  same, valid = "SAME", "VALID"
  # Every row spells out all six fields. `dilations` is None for the configs
  # that do not dilate the filter.
  return [
      ([4, 5, 5, 48], [1, 1, 48, 2], [4, 5, 5, 96], 1, same, None),
      ([4, 8, 8, 84], [1, 3, 84, 1], [4, 8, 8, 84], 1, same, None),
      ([4, 17, 17, 48], [3, 1, 48, 4], [4, 17, 17, 192], 1, same, None),
      ([4, 9, 27, 8], [3, 3, 8, 1], [4, 9, 27, 8], 1, same, None),
      ([4, 31, 31, 7], [3, 3, 7, 1], [4, 31, 31, 7], 1, same, None),
      ([4, 35, 35, 2], [5, 5, 2, 1], [4, 35, 35, 2], 1, same, None),
      ([4, 147, 147, 2], [3, 3, 2, 8], [4, 49, 49, 16], 3, valid, None),
      ([3, 299, 299, 3], [3, 2, 3, 8], [3, 150, 150, 24], 2, same, None),
      ([5, 183, 183, 1], [5, 5, 1, 2], [5, 92, 92, 2], 2, same, None),
      ([5, 183, 183, 1], [5, 5, 1, 2], [5, 183, 183, 2], 1, same, [2, 2]),
      ([5, 41, 35, 2], [4, 7, 2, 2], [5, 32, 23, 4], 1, valid, [3, 2]),
  ]
172
173
def ConfigsToTestExplicit():
  """Returns convolution configs that use explicit paddings.

  The padding of each config is a pair of [before, after] lists, one for the
  rows and one for the columns.

  Returns:
    List of tuples (input_size, filter_size, out_size, stride, padding,
    dilations), the depthwise convolution parameters.
  """
  general_configs = [
      ([4, 5, 5, 48], [1, 1, 48, 2], [4, 8, 12, 96], 1, [[1, 2], [3, 4]],
       None),
      ([4, 1, 1, 3], [3, 3, 3, 2], [4, 29, 39, 6], 1, [[10, 20], [15, 25]],
       None),
      ([4, 9, 27, 8], [3, 3, 8, 1], [4, 14, 31, 8], 1, [[3, 4], [4, 2]],
       None),
      ([4, 31, 31, 7], [3, 3, 7, 1], [4, 29, 29, 7], 1, [[0, 0], [0, 0]],
       None),
      ([3, 299, 299, 3], [3, 2, 3, 8], [3, 150, 153, 24], 2,
       [[1, 2], [3, 5]], None),
      ([5, 183, 183, 1], [5, 5, 1, 2], [5, 62, 60, 2], 3, [[3, 2], [1, 0]],
       None),
      ([5, 29, 31, 1], [5, 4, 1, 2], [5, 26, 23, 2], 1, [[3, 2], [1, 0]],
       [2, 3]),
  ]
  # These cases test the kernels in depthwise_conv_op_gpu.h which are used
  # if the input size is small.
  small_input_configs = [
      ([4, 5, 5, 48], [3, 3, 48, 1], [4, 5, 5, 48], 1, [[0, 2], [0, 2]],
       None),
      ([1, 8, 7, 2], [8, 7, 2, 1], [1, 8, 7, 2], 1, [[0, 7], [3, 3]], None),
      ([2, 4, 3, 2], [3, 2, 2, 1], [2, 4, 3, 2], 1, [[2, 0], [1, 0]], None),
  ]
  return general_configs + small_input_configs
217
218
def CheckGradConfigsToTest():
  """Returns gradient-check configs that use "SAME" or "VALID" padding.

  compute_gradient_error() is very expensive. So the configs should be
  relatively small.

  Returns:
    List of tuples (input_size, filter_size, out_size, stride, padding,
    dilations), the depthwise convolution parameters.
  """
  same, valid = "SAME", "VALID"
  general_configs = [
      ([2, 5, 8, 1], [4, 4, 1, 2], [2, 5, 8, 2], 1, same, None),
      ([4, 5, 5, 1], [2, 2, 1, 2], [4, 2, 2, 2], 2, valid, None),
      ([2, 4, 4, 2], [3, 1, 2, 2], [2, 4, 4, 4], 1, same, None),
      ([1, 15, 15, 2], [1, 3, 2, 1], [1, 15, 15, 2], 1, same, None),
      ([2, 15, 16, 1], [3, 3, 1, 2], [2, 5, 5, 2], 3, valid, None),
      ([2, 5, 8, 1], [4, 3, 1, 2], [2, 5, 8, 2], 1, same, [1, 2]),
  ]
  # These cases test the kernels in depthwise_conv_op_gpu.h which are used
  # if the input size is small.
  small_input_configs = [
      ([1, 3, 1, 2], [2, 1, 2, 1], [1, 3, 1, 2], 1, same, None),
      ([2, 2, 3, 2], [2, 1, 2, 1], [2, 2, 3, 2], 1, same, None),
      ([2, 2, 3, 1], [2, 2, 1, 1], [2, 2, 3, 1], 1, same, None),
  ]
  return general_configs + small_input_configs
251
252
def CheckGradConfigsToTestExplicit():
  """Returns gradient-check configs that use explicit paddings.

  compute_gradient_error() is very expensive. So the configs should be
  relatively small.

  Returns:
    List of tuples (input_size, filter_size, out_size, stride, padding,
    dilations), the depthwise convolution parameters.
  """
  general_configs = [
      ([2, 5, 8, 1], [4, 4, 1, 2], [2, 3, 10, 2], 1, [[0, 1], [2, 3]], None),
      ([4, 5, 5, 1], [2, 2, 1, 2], [4, 4, 5, 2], 2, [[3, 1], [5, 0]], None),
      ([2, 4, 4, 2], [3, 1, 2, 2], [2, 7, 11, 4], 1, [[4, 1], [3, 4]], None),
      ([1, 15, 15, 2], [1, 3, 2, 1], [1, 18, 23, 2], 1, [[3, 0], [2, 8]],
       None),
      ([2, 15, 16, 1], [3, 3, 1, 2], [2, 5, 8, 2], 3, [[0, 0], [10, 0]],
       None),
      ([2, 5, 8, 1], [3, 4, 1, 2], [2, 5, 10, 2], 1, [[3, 1], [2, 3]],
       [2, 1]),
  ]
  # These cases test the kernels in depthwise_conv_op_gpu.h which are used
  # if the input size is small.
  small_input_configs = [
      ([2, 4, 3, 2], [3, 2, 2, 1], [2, 4, 3, 2], 1, [[2, 0], [1, 0]], None),
  ]
  return general_configs + small_input_configs
293
294
295class DepthwiseConv2DBase(test.TestCase):
296  """Base test class for depthwise Conv2D tests."""
297
298  # This tests depthwise_conv2d and depthwise_conv2d_native
299  def _VerifyValues(self,
300                    tensor_in_sizes,
301                    filter_in_sizes,
302                    stride,
303                    padding,
304                    data_type,
305                    use_gpu,
306                    grouped_conv=False,
307                    data_format="NHWC",
308                    dilations=None,
309                    tolerance=None):
310    """Verifies the output values of the convolution function.
311
312    Args:
313      tensor_in_sizes: Input tensor dimensions in [batch, input_rows,
314        input_cols, input_depth].
315      filter_in_sizes: Filter tensor dimensions in [filter_rows, filter_cols,
316        input_depth, depth_multiplier].
317      stride: Stride.
318      padding: Padding type.
319      data_type: The data type to use.
320      use_gpu: Whether to use GPU.
321      grouped_conv: Whether to use cuDNN 7's grouped convolution.
322      data_format: The data_format of the input. "NHWC" or "NCHW".
323      dilations: A list of 2 elements, representing the dilations.
324      tolerance: The absolute and relative tolarance when verifying the output.
325    """
326    input_size = 1
327    filter_size = 1
328    for s in tensor_in_sizes:
329      input_size *= s
330    for s in filter_in_sizes:
331      filter_size *= s
332    # Initializes the input and filter tensor with numbers incrementing to 1.0.
333    x1 = [f * 1.0 / input_size for f in range(1, input_size + 1)]
334    x1 = np.array(x1).reshape(tensor_in_sizes)
335    x2 = [f * 1.0 / filter_size for f in range(1, filter_size + 1)]
336    x2 = np.array(x2).reshape(filter_in_sizes)
337    # Compute reference result
338    strides = [1, stride, stride, 1]
339    if isinstance(padding, list):
340      padding = [(0, 0)] + padding + [(0, 0)]
341    np_result = _DepthwiseConv2dNumpy(x1, x2, strides, padding, "NHWC",
342                                      dilations)
343
344    ops.reset_default_graph()
345    graph = ops.get_default_graph()
346    with self.session(graph=graph, use_gpu=use_gpu) as sess:
347      tolerance = tolerance or {
348          dtypes.float16: 4e-2,
349          dtypes.float32: 1e-5,
350          dtypes.float64: 1e-12,
351          dtypes.bfloat16: 1e-2,
352      }[data_type]
353
354      t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=data_type)
355      t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=data_type)
356
357      if data_format == "NCHW":
358        # Transpose from NHWC input to NCHW
359        # Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
360        t1 = array_ops.transpose(t1, [0, 3, 1, 2])
361        strides = [1, 1, stride, stride]
362        if isinstance(padding, list):
363          padding = [padding[0], padding[3], padding[1], padding[2]]
364
365      # depthwise_conv2d_native does not support dilations except on TPUs.
366      if dilations is None:
367        with sess.graph._kernel_label_map(  # pylint: disable=protected-access
368            {"DepthwiseConv2dNative": "cudnn_grouped_convolution"}
369            if grouped_conv else {}):
370          conv_native = nn_ops.depthwise_conv2d_native(
371              t1, t2, strides=strides, data_format=data_format, padding=padding)
372
373        if data_format == "NCHW":
374          # Transpose back from NCHW to NHWC
375          conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
376
377        try:
378          # The Numpy array from calling depthwise_conv2d_native
379          native_result = self.evaluate(conv_native)
380        except errors.InvalidArgumentError as e:
381          # Grouped convolution kernel is only registered for cuDNN 7. Silently
382          # return when we are running on an earlier version or without GPU.
383          if ("No OpKernel was registered to support Op "
384              "'DepthwiseConv2dNative'") in e.message:
385            tf_logging.warn("Skipping grouped convolution test")
386            return
387          raise e
388
389      conv_interface = nn_impl.depthwise_conv2d(
390          t1,
391          t2,
392          strides=strides,
393          padding=padding,
394          data_format=data_format,
395          dilations=dilations)
396      if data_format == "NCHW":
397        # Transpose back from NCHW to NHWC
398        conv_interface = array_ops.transpose(conv_interface, [0, 2, 3, 1])
399
400      # The Numpy array from calling depthwise_conv2d
401      interface_result = self.evaluate(conv_interface)
402
403    if dilations is None:
404      self.assertAllClose(
405          native_result, np_result, atol=tolerance, rtol=tolerance)
406    self.assertAllClose(
407        interface_result, np_result, atol=tolerance, rtol=tolerance)
408
409  @test_util.run_v1_only("b/120545219")
410  @test_util.run_cuda_only
411  def testDepthwiseConv2DCudnn(self):
412    for index, (input_size, filter_size, _, stride, padding,
413                dilations) in enumerate(ConfigsToTest()):
414      # The CuDNN depthwise conv is turned on only when input/output is NCHW and
415      # float16(half). See cudnn release note 7.6.3.
416      tf_logging.info(
417          "Testing DepthwiseConv2DCudnn, %dth config: %r * %r, stride: %d, "
418          "padding: %s", index, input_size, filter_size, stride, padding)
419      data_type = dtypes.float16
420      self._VerifyValues(
421          input_size,
422          filter_size,
423          stride,
424          padding,
425          data_type,
426          use_gpu=True,
427          data_format="NCHW",
428          dilations=dilations)
429
430  @test_util.run_v1_only("b/120545219")
431  def testDepthwiseConv2D(self):
432    for index, (input_size, filter_size, _, stride, padding,
433                dilations) in enumerate(ConfigsToTest()):
434      tf_logging.info(
435          "Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: "
436          "%s", index, input_size, filter_size, stride, padding)
437      # double datatype is currently not supported for convolution ops
438      # on the ROCm platform
439      optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
440      for data_type in ([dtypes.float32] + optional_float64):
441        tf_logging.info("Testing without grouped_conv")
442        tolerance = 1e-4 if data_type == dtypes.float32 else 1e-12
443        self._VerifyValues(
444            input_size,
445            filter_size,
446            stride,
447            padding,
448            data_type,
449            use_gpu=True,
450            dilations=dilations,
451            tolerance=tolerance)
452        tf_logging.info("Testing with grouped_conv")
453        self._VerifyValues(
454            input_size,
455            filter_size,
456            stride,
457            padding,
458            data_type,
459            use_gpu=True,
460            grouped_conv=True,
461            dilations=dilations,
462            tolerance=tolerance)
463
464  @test_util.run_v1_only("b/120545219")
465  def testDepthwiseConv2DWithUnknownShape(self):
466    # GitHub issue 22110.
467    if not test.is_gpu_available():
468      return
469    with self.session():
470      x = array_ops.placeholder(dtypes.float32)
471      f = np.ones([1, 1, 1, 1], np.float32)
472      v = nn_impl.depthwise_conv2d(
473          x, f, [1, 1, 1, 1], "VALID", rate=[2, 1], data_format="NCHW")
474      self.assertAllEqual(
475          np.ones([1, 1, 1, 1], np.float32),
476          v.eval(feed_dict={x: np.ones([1, 1, 1, 1], np.float32)}))
477
478  @test_util.run_v1_only("b/120545219")
479  def testDepthwiseConv2DFormat(self):
480    if not test.is_gpu_available():
481      return
482
483    for index, (input_size, filter_size, _, stride, padding,
484                dilations) in enumerate(ConfigsToTest()):
485      tf_logging.info(
486          "Testing DepthwiseConv2DFormat, %dth config: %r * %r, stride: %d, "
487          "padding: %s", index, input_size, filter_size, stride, padding)
488      # double datatype is currently not supported for convolution ops
489      # on the ROCm platform
490      optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
491      for data_type in ([dtypes.float32] + optional_float64):
492        tolerance = 1e-4 if data_type == dtypes.float32 else 1e-12
493        self._VerifyValues(
494            input_size,
495            filter_size,
496            stride,
497            padding,
498            data_type,
499            use_gpu=True,
500            data_format="NCHW",
501            dilations=dilations,
502            tolerance=tolerance)
503
504  @test_util.run_v1_only("b/120545219")
505  def testDepthwiseConv2DExplicit(self):
506    for index, (input_size, filter_size, _, stride, padding,
507                dilations) in enumerate(ConfigsToTestExplicit()):
508      tf_logging.info(
509          "Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: "
510          "%s", index, input_size, filter_size, stride, padding)
511      # double datatype is currently not supported for convolution ops
512      # on the ROCm platform
513      optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
514      data_formats = ["NHWC", "NCHW"] if test.is_gpu_available() else ["NHWC"]
515      for data_type in [dtypes.float16, dtypes.float32] + optional_float64:
516        for data_format in data_formats:
517          self._VerifyValues(
518              input_size,
519              filter_size,
520              stride,
521              padding,
522              data_type,
523              use_gpu=True,
524              data_format=data_format,
525              dilations=dilations)
526
527
  # The tests below check against hand-calculated results.
529
530  def _VerifyHandValues(self, tensor_in_sizes, filter_in_sizes, stride, padding,
531                        expected, use_gpu):
532    """Verifies the output values of the depthwise convolution function.
533
534    Args:
535      tensor_in_sizes: Input tensor dimensions in [batch, input_rows,
536        input_cols, input_depth].
537      filter_in_sizes: Filter tensor dimensions in [filter_rows, filter_cols,
538        input_depth, depth_multiplier].
539      stride: Stride.
540      padding: Padding type.
541      expected: An array containing the expected operation outputs.
542      use_gpu: Whether to use GPU.
543    """
544    total_size_1 = 1
545    total_size_2 = 1
546    for s in tensor_in_sizes:
547      total_size_1 *= s
548    for s in filter_in_sizes:
549      total_size_2 *= s
550    # Initializes the input tensor with array containing incrementing
551    # numbers from 1.
552    x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
553    x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
554    with self.cached_session(use_gpu=use_gpu) as sess:
555      t1 = constant_op.constant(x1, shape=tensor_in_sizes)
556      t1.set_shape(tensor_in_sizes)
557      t2 = constant_op.constant(x2, shape=filter_in_sizes)
558      conv = nn_ops.depthwise_conv2d_native(
559          t1, t2, strides=[1, stride, stride, 1], padding=padding)
560      value = self.evaluate(conv)
561    tf_logging.info("value = %r", value)
562    self.assertArrayNear(expected, np.ravel(value), 1e-5)
563    self.assertShapeEqual(value, conv)
564
565  def testConv2D2x2Filter(self):
566    # The inputs look like this (it's a 3 x 2 matrix, each of depth 2):
567    #
568    # [ (1.0, 2.0), (3.0,  4.0), ( 5.0,  6.0) ]
569    # [ (7.0, 8.0), (9.0, 10.0), (11.0, 12.0) ]
570    #  We can view this as two inputs
571    #
572    #  input depth 0:
573    #
574    #  [ 1.0,  3.0,  5.0 ]
575    #  [ 7.0,  9.0, 11.0 ]
576    #
577    #  input depth 1:
578    #
579    #  [ 2.0,  4.0,  6.0 ]
580    #  [ 8.0, 10.0, 12.0 ]
581    #
582    # The filter looks like this (it has two 2 x 2 patches, each generating 2
583    # depths):
584    #
585    #  filter #0:
586    #
587    #  [ (1.0,  3.0), ( 5.0,  7.0)]
588    #  [ (9.0, 11.0), (13.0, 15.0)]
589    #
590    #  filter #1:
591    #
592    #  [ ( 2.0,  4.0), ( 6.0,  8.0)]
593    #  [ (10.0, 12.0), (14.0, 16.0)]
594    #
595    # So the outputs are:
596    #
597    # (position 0, 0: in_depth 0, output_depth 0 -- using filter #0)
598    #  1.0 * 1.0 + 7.0 * 9.0 + 3.0 * 5.0 + 9.0 * 13.0 = 196
599    # (position 0, 0: in_depth 0, output_depth 1 -- using filter #1)
600    #  1.0 * 2.0 + 7.0 * 10.0 + 3.0 * 6.0 + 9.0 * 14.0 = 216
601    # (position 0, 0: in_depth 1, output_depth 2 -- using filter #0)
602    #  2.0 * 3.0 + 8.0 * 11.0 + 4.0 * 7.0 + 10.0 * 15.0 = 272
603    # (position 0, 0: in_depth 1, output_depth 3 -- using filter #1)
604    #  2.0 * 4.0 + 8.0 * 12.0 + 4.0 * 8.0 + 10.0 * 16.0 = 296
605    #
606    # (position 1, 0: in_depth 0, output_depth 0 -- using filter #0)
607    #  3.0 * 1.0 + 9.0 * 9.0 + 5.0 * 5.0 + 11.0 * 13.0 = 252
608    # (position 1, 0: in_depth 0, output_depth 1 -- using filter #1)
609    #  3.0 * 2.0 + 9.0 * 10.0 + 5.0 * 6.0 + 11.0 * 14.0 = 280
610    # (position 1, 0: in_depth 1, output_depth 2 -- using filter #0)
611    #  4.0 * 3.0 + 10.0 * 11.0 + 6.0 * 7.0 + 12.0 * 15.0 = 344
612    # (position 1, 0: in_depth 1, output_depth 3 -- using filter #1)
613    #  4.0 * 4.0 + 10.0 * 12.0 + 6.0 * 8.0 + 12.0 * 16.0 = 376
614    expected_output = [196, 216, 272, 296, 252, 280, 344, 376]
615    self._VerifyHandValues(
616        tensor_in_sizes=[1, 2, 3, 2],
617        filter_in_sizes=[2, 2, 2, 2],
618        stride=1,
619        padding="VALID",
620        expected=expected_output,
621        use_gpu=False)
622
623    self._VerifyHandValues(
624        tensor_in_sizes=[1, 2, 3, 2],
625        filter_in_sizes=[2, 2, 2, 2],
626        stride=1,
627        padding="VALID",
628        expected=expected_output,
629        use_gpu=True)
630
  # Gradient checkers. These test the depthwise gradient computations for both
  # BackpropFilter and BackpropInput by comparing the gradients computed by the
  # depthwise gradient ops with gradients computed numerically (details can be
  # found in compute_gradient_error()).
  # Note that this check is very expensive, so the inputs should not be too big.
636  def _ConstructAndTestGradient(self,
637                                input_shape,
638                                filter_shape,
639                                output_shape,
640                                stride,
641                                padding,
642                                data_type,
643                                test_input,
644                                use_gpu,
645                                grouped_conv=False,
646                                data_format="NHWC",
647                                dilations=None):
648    input_size = 1
649    for x in input_shape:
650      input_size *= x
651    filter_size = 1
652    for x in filter_shape:
653      filter_size *= x
654    input_data = [x * 1.0 / input_size for x in range(0, input_size)]
655    input_np = np.array(input_data).reshape(input_shape)
656    filter_data = [x * 1.0 / filter_size for x in range(0, filter_size)]
657    filter_np = np.array(filter_data).reshape(filter_shape)
658    ops.reset_default_graph()
659    graph = ops.get_default_graph()
660    with self.session(graph=graph, use_gpu=use_gpu) as sess:
661      tolerance = {
662          dtypes.float16: 4e-0,
663          dtypes.float32: 8e-4,
664          dtypes.float64: 1e-12,
665      }[data_type]
666
667      input_tensor = constant_op.constant(
668          input_np, shape=input_shape, dtype=data_type, name="input")
669      filter_tensor = constant_op.constant(
670          filter_np, shape=filter_shape, dtype=data_type, name="filter")
671
672      native_input = input_tensor
673      strides = [1, stride, stride, 1]
674      if isinstance(padding, list):
675        padding = [(0, 0)] + padding + [(0, 0)]
676      if data_format == "NCHW":
677        # Transpose from NHWC input to NCHW
678        # Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
679        native_input = array_ops.transpose(input_tensor, [0, 3, 1, 2])
680        input_shape = [
681            input_shape[0], input_shape[3], input_shape[1], input_shape[2]
682        ]
683        output_shape = [
684            output_shape[0], output_shape[3], output_shape[1], output_shape[2]
685        ]
686        strides = [1, 1, stride, stride]
687        if isinstance(padding, list):
688          padding = [padding[0], padding[3], padding[1], padding[2]]
689
690      with sess.graph._kernel_label_map({  # pylint: disable=protected-access,g-long-ternary
691          "DepthwiseConv2dNative": "cudnn_grouped_convolution",
692          "DepthwiseConv2dNativeBackpropInput": "cudnn_grouped_convolution",
693          "DepthwiseConv2dNativeBackpropFilter": "cudnn_grouped_convolution",
694      } if grouped_conv else {}):
695        depthwise_conv2d = nn_impl.depthwise_conv2d(
696            native_input,
697            filter_tensor,
698            strides,
699            padding,
700            data_format=data_format,
701            dilations=dilations,
702            name="depthwise_conv2d")
703
704      self.assertEqual(output_shape, depthwise_conv2d.get_shape())
705
706      try:
707        if test_input:
708          err = gradient_checker.compute_gradient_error(native_input,
709                                                        input_shape,
710                                                        depthwise_conv2d,
711                                                        output_shape)
712        else:
713          err = gradient_checker.compute_gradient_error(filter_tensor,
714                                                        filter_shape,
715                                                        depthwise_conv2d,
716                                                        output_shape)
717      except errors.InvalidArgumentError as e:
718        # TODO(xjun): Tests depend on error messages could be brittle.
719        # Grouped convolution kernel is only registered for cuDNN 7. Silently
720        # return when we are running on an earlier version or without GPU.
721        if grouped_conv and ("No OpKernel was registered to support Op "
722                             "'DepthwiseConv2dNative'") in e.message:
723          tf_logging.warn("Skipping grouped convolution test")
724          return
725        raise e
726
727      tf_logging.info(
728          "data_type: %r, use_gpu: %r, grouped_conv: %r, error = %f", data_type,
729          use_gpu, grouped_conv, err)
730      self.assertLess(err, tolerance)
731
732  @test_util.run_v1_only("b/120545219")
733  @test_util.run_cuda_only
734  def testDepthwiseConv2DInputGradCudnn(self):
735    for index, (input_size, filter_size, output_size, stride, padding,
736                dilations) in enumerate(CheckGradConfigsToTest()):
737      # The CuDNN depthwise conv (input gradient) is turned on only when
738      # stride = 1, input/output is NCHW and float16(half). See cudnn release
739      # note 7.6.3.
740      if stride != 1:
741        continue
742      tf_logging.info(
743          "Testing DepthwiseConv2DInputGradCudnn, %dth config: %r * %r, "
744          "stride: %d, padding: %s", index, input_size, filter_size, stride,
745          padding)
746      data_type = dtypes.float16
747      self._ConstructAndTestGradient(
748          input_size,
749          filter_size,
750          output_size,
751          stride,
752          padding,
753          data_type,
754          test_input=True,
755          use_gpu=True,
756          data_format="NCHW",
757          dilations=dilations)
758
759  @test_util.run_v1_only("b/120545219")
760  def testDepthwiseConv2DInputGrad(self):
761    for index, (input_size, filter_size, output_size, stride, padding,
762                dilations) in enumerate(CheckGradConfigsToTest()):
763      tf_logging.info(
764          "Testing DepthwiseConv2DInputGrad, %dth config: %r * %r, stride: %d, "
765          "padding: %s", index, input_size, filter_size, stride, padding)
766      # double datatype is currently not supported for convolution ops
767      # on the ROCm platform
768      optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
769      for data_type in ([dtypes.float32] + optional_float64):
770        self._ConstructAndTestGradient(
771            input_size,
772            filter_size,
773            output_size,
774            stride,
775            padding,
776            data_type,
777            test_input=True,
778            use_gpu=True,
779            dilations=dilations)
780        self._ConstructAndTestGradient(
781            input_size,
782            filter_size,
783            output_size,
784            stride,
785            padding,
786            data_type,
787            test_input=True,
788            use_gpu=True,
789            grouped_conv=True,
790            dilations=dilations)
791
792  @test_util.run_v1_only("b/120545219")
793  def testDepthwiseConv2DInputGradFormat(self):
794    if not test.is_gpu_available():
795      return
796
797    for index, (input_size, filter_size, output_size, stride, padding,
798                dilations) in enumerate(CheckGradConfigsToTest()):
799      tf_logging.info(
800          "Testing DepthwiseConv2DInputGradFormat, %dth config: %r * %r, "
801          "stride: %d, padding: %s", index, input_size, filter_size, stride,
802          padding)
803      # double datatype is currently not supported for convolution ops
804      # on the ROCm platform
805      optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
806      for data_type in ([dtypes.float32] + optional_float64):
807        self._ConstructAndTestGradient(
808            input_size,
809            filter_size,
810            output_size,
811            stride,
812            padding,
813            data_type,
814            test_input=True,
815            use_gpu=True,
816            data_format="NCHW",
817            dilations=dilations)
818
819  @test_util.run_v1_only("b/120545219")
820  def testDepthwiseConv2DInputGradExplicit(self):
821    for index, (input_size, filter_size, output_size, stride, padding,
822                dilations) in enumerate(CheckGradConfigsToTestExplicit()):
823      tf_logging.info(
824          "Testing DepthwiseConv2DInputGradExplicit, %dth config: %r * %r, "
825          "stride: %d, padding: %s", index, input_size, filter_size, stride,
826          padding)
827      # double datatype is currently not supported for convolution ops
828      # on the ROCm platform
829      optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
830      data_formats = ["NHWC", "NCHW"] if test.is_gpu_available() else ["NHWC"]
831      for data_type in [dtypes.float16, dtypes.float32] + optional_float64:
832        for data_format in data_formats:
833          self._ConstructAndTestGradient(
834              input_size,
835              filter_size,
836              output_size,
837              stride,
838              padding,
839              data_type,
840              test_input=True,
841              use_gpu=True,
842              data_format=data_format,
843              dilations=dilations)
844
845  @test_util.run_v1_only("b/120545219")
846  @test_util.run_cuda_only
847  def testDepthwiseConv2DFilterGradCudnn(self):
848    for index, (input_size, filter_size, output_size, stride, padding,
849                dilations) in enumerate(CheckGradConfigsToTest()):
850      # The CuDNN depthwise conv (filter gradient) is turned on only when
851      # input/output is float16(half). See cudnn release note 7.6.3.
852      tf_logging.info(
853          "Testing DepthwiseConv2DFilterGradCudnn, %dth config: %r * %r, "
854          "stride: %d, padding: %s", index, input_size, filter_size, stride,
855          padding)
856      data_type = dtypes.float16
857      self._ConstructAndTestGradient(
858          input_size,
859          filter_size,
860          output_size,
861          stride,
862          padding,
863          data_type,
864          test_input=False,
865          use_gpu=True,
866          data_format="NCHW",
867          dilations=dilations)
868      self._ConstructAndTestGradient(
869          input_size,
870          filter_size,
871          output_size,
872          stride,
873          padding,
874          data_type,
875          test_input=False,
876          use_gpu=True,
877          data_format="NHWC",
878          dilations=dilations)
879
880  @test_util.run_v1_only("b/120545219")
881  def testDepthwiseConv2DFilterGrad(self):
882    for index, (input_size, filter_size, output_size, stride, padding,
883                dilations) in enumerate(CheckGradConfigsToTest()):
884      tf_logging.info(
885          "Testing DepthwiseConv2DFilterGrad, %dth config: %r * %r, stride: "
886          "%d, padding: %s", index, input_size, filter_size, stride, padding)
887      # double datatype is currently not supported for convolution ops
888      # on the ROCm platform
889      optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
890      for data_type in ([dtypes.float16, dtypes.float32] + optional_float64):
891        self._ConstructAndTestGradient(
892            input_size,
893            filter_size,
894            output_size,
895            stride,
896            padding,
897            data_type,
898            test_input=False,
899            use_gpu=True,
900            dilations=dilations)
901
902  @test_util.run_v1_only("b/120545219")
903  def testDepthwiseConv2DFilterGradFormat(self):
904    if not test.is_gpu_available():
905      return
906
907    for index, (input_size, filter_size, output_size, stride, padding,
908                dilations) in enumerate(CheckGradConfigsToTest()):
909      tf_logging.info(
910          "Testing DepthwiseConv2DFilterGradFormat, %dth config: %r * %r, "
911          "stride: %d, padding: %s", index, input_size, filter_size, stride,
912          padding)
913      # double datatype is currently not supported for convolution ops
914      # on the ROCm platform
915      optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
916      for data_type in ([dtypes.float32] + optional_float64):
917        self._ConstructAndTestGradient(
918            input_size,
919            filter_size,
920            output_size,
921            stride,
922            padding,
923            data_type,
924            test_input=False,
925            use_gpu=True,
926            data_format="NCHW",
927            dilations=dilations)
928
929  @test_util.run_v1_only("b/120545219")
930  def testDepthwiseConv2DFilterGradExplicit(self):
931    for index, (input_size, filter_size, output_size, stride, padding,
932                dilations) in enumerate(CheckGradConfigsToTestExplicit()):
933      tf_logging.info(
934          "Testing DepthwiseConv2DFilterGradExplicit, %dth config: %r * %r, "
935          "stride: %d, padding: %s", index, input_size, filter_size, stride,
936          padding)
937      # double datatype is currently not supported for convolution ops
938      # on the ROCm platform
939      optional_float64 = [] if test.is_built_with_rocm() else [dtypes.float64]
940      data_formats = ["NHWC", "NCHW"] if test.is_gpu_available() else ["NHWC"]
941      for data_type in [dtypes.float16, dtypes.float32] + optional_float64:
942        for data_format in data_formats:
943          self._ConstructAndTestGradient(
944              input_size,
945              filter_size,
946              output_size,
947              stride,
948              padding,
949              data_type,
950              test_input=False,
951              use_gpu=True,
952              data_format=data_format,
953              dilations=dilations)
954
955  def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes,
956                            stride, padding, dtype):
957    x1 = np.random.rand(*filter_sizes).astype(dtype)
958    x2 = np.random.rand(*output_sizes).astype(dtype)
959    if isinstance(padding, list):
960      padding = [(0, 0)] + padding + [(0, 0)]
961
962    def _GetVal(use_gpu):
963      with self.cached_session(use_gpu=use_gpu):
964        t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
965        t1 = constant_op.constant(x1, shape=filter_sizes)
966        t2 = constant_op.constant(x2, shape=output_sizes)
967        backprop = nn_ops.depthwise_conv2d_native_backprop_input(
968            t0, t1, t2, strides=[1, stride, stride, 1], padding=padding)
969        ret = self.evaluate(backprop)
970        self.assertShapeEqual(ret, backprop)
971        return ret
972
973    gpu_value = _GetVal(use_gpu=True)
974    cpu_value = _GetVal(use_gpu=False)
975    self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
976
977  @test_util.run_gpu_only
978  def testDepthwiseConv2DInputGradCompare(self):
979    for index, (input_size, filter_size, output_size, stride, padding,
980                dilations) in enumerate(ConfigsToTest()):
981      if dilations:
982        continue
983      tf_logging.info(
984          "Testing DepthwiseConv2DInputGradCompare, %dth config: %r * %r, "
985          "stride: %d, padding: %s", index, input_size, filter_size, stride,
986          padding)
987      self._CompareBackpropInput(input_size, filter_size, output_size, stride,
988                                 padding, "float32")
989      # double datatype is currently not supported for convolution ops
990      # on the ROCm platform
991      if test.is_built_with_rocm():
992        continue
993      self._CompareBackpropInput(input_size, filter_size, output_size, stride,
994                                 padding, "float64")
995
996  @test_util.run_gpu_only
997  def testDepthwiseConv2DInputGradExplicitCompare(self):
998    for index, (input_size, filter_size, output_size, stride, padding,
999                dilations) in enumerate(ConfigsToTestExplicit()):
1000      if dilations:
1001        continue
1002      tf_logging.info(
1003          "Testing DepthwiseConv2DInputGradCompare, %dth config: %r * %r, "
1004          "stride: %d, padding: %s", index, input_size, filter_size, stride,
1005          padding)
1006      self._CompareBackpropInput(input_size, filter_size, output_size, stride,
1007                                 padding, "float32")
1008      # double datatype is currently not supported for convolution ops
1009      # on the ROCm platform
1010      if test.is_built_with_rocm():
1011        continue
1012      self._CompareBackpropInput(input_size, filter_size, output_size, stride,
1013                                 padding, "float64")
1014
1015  def _CompareBackpropFilter(self, input_sizes, filter_sizes, output_sizes,
1016                             stride, padding, dtype):
1017    x0 = np.random.rand(*input_sizes).astype(dtype)
1018    x2 = np.random.rand(*output_sizes).astype(dtype)
1019    padding_nhwc = padding
1020    padding_nchw = padding
1021    if isinstance(padding, list):
1022      padding_nhwc = [(0, 0)] + padding + [(0, 0)]
1023      padding_nchw = [(0, 0)] + [(0, 0)] + padding
1024
1025    def _GetVal(use_gpu, data_format="NHWC"):
1026      with self.cached_session(use_gpu=use_gpu):
1027        t0 = constant_op.constant(x0, shape=input_sizes)
1028        t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
1029        t2 = constant_op.constant(x2, shape=output_sizes)
1030        strides = [1, stride, stride, 1]
1031        padding = padding_nhwc
1032        if data_format == "NCHW":
1033          t0 = array_ops.transpose(t0, [0, 3, 1, 2])
1034          t2 = array_ops.transpose(t2, [0, 3, 1, 2])
1035          strides = [1, 1, stride, stride]
1036          padding = padding_nchw
1037        backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
1038            t0,
1039            t1,
1040            t2,
1041            strides=strides,
1042            padding=padding,
1043            data_format=data_format)
1044        ret = self.evaluate(backprop)
1045        self.assertShapeEqual(ret, backprop)
1046        return ret
1047
1048    cpu_value = _GetVal(use_gpu=False)
1049    for data_format in ["NHWC", "NCHW"]:
1050      gpu_value = _GetVal(use_gpu=True, data_format=data_format)
1051      self.assertAllClose(cpu_value, gpu_value, rtol=1e-4, atol=1e-4)
1052
1053  @test_util.run_gpu_only
1054  def testDepthwiseConv2DFilterGradCompare(self):
1055    for index, (input_size, filter_size, output_size, stride, padding,
1056                dilations) in enumerate(ConfigsToTest()):
1057      if dilations:
1058        continue
1059      tf_logging.info(
1060          "Testing DepthwiseConv2DFilterGradCompare, %dth config: %r * %r, "
1061          "stride: %d, padding: %s", index, input_size, filter_size, stride,
1062          padding)
1063      self._CompareBackpropFilter(input_size, filter_size, output_size, stride,
1064                                  padding, "float32")
1065      # double datatype is currently not supported for convolution ops
1066      # on the ROCm platform
1067      if test.is_built_with_rocm():
1068        continue
1069      self._CompareBackpropFilter(input_size, filter_size, output_size, stride,
1070                                  padding, "float64")
1071
1072  @test_util.run_gpu_only
1073  def testDepthwiseConv2DFilterGradExplicitCompare(self):
1074    for index, (input_size, filter_size, output_size, stride, padding,
1075                dilations) in enumerate(ConfigsToTestExplicit()):
1076      if dilations:
1077        continue
1078      tf_logging.info(
1079          "Testing DepthwiseConv2DFilterGradCompare, %dth config: %r * %r, "
1080          "stride: %d, padding: %s", index, input_size, filter_size, stride,
1081          padding)
1082      self._CompareBackpropFilter(input_size, filter_size, output_size, stride,
1083                                  padding, "float32")
1084      # double datatype is currently not supported for convolution ops
1085      # on the ROCm platform
1086      if test.is_built_with_rocm():
1087        continue
1088      self._CompareBackpropFilter(input_size, filter_size, output_size, stride,
1089                                  padding, "float64")
1090