# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for deterministic depthwise convolutional operations."""

from tensorflow.python.eager import backprop
from tensorflow.python.framework import config as tf_config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.nn_ops import depthwise_conv_op_base
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import random_ops
# The following imports are required to register the gradient functions.
from tensorflow.python.ops.nn_grad import _DepthwiseConv2dNativeBackpropFilterGrad  # pylint: disable=unused-import
from tensorflow.python.ops.nn_grad import _DepthwiseConv2dNativeBackpropInputGrad  # pylint: disable=unused-import
from tensorflow.python.platform import test


@test_util.run_all_without_tensor_float_32("Uses matmul")
class DepthwiseConv2DDeterministicTest(
    depthwise_conv_op_base.DepthwiseConv2DBase):
  """Test determinism-related functionality of tf.nn.depthwise_conv2d."""

  def _genParams(self,
                 use_cudnn=False,
                 data_format="NHWC",
                 dtype=dtypes.float32,
                 seed=123):
    random_seed.set_seed(seed)
    batch_size = 2  # no interaction over batch, so make small
    if use_cudnn:
      # When op-determinism is not enabled, one input channel, plus a
      # cuDNN-supported filter size and number of output channels, will
      # result in cuDNN being used for both backprop-to-input and
      # backprop-to-filter on cuDNN 7.6.3 and higher. When op-determinism is
      # enabled, cuDNN is always used for backprop-to-filter.
      input_channels = 1
    else:
      input_channels = 2  # no interaction over channels, so make small
    input_height = 500
    input_width = 1000
    if data_format == "NHWC":
      input_shape = (batch_size, input_height, input_width, input_channels)
    else:  # "NCHW"
      input_shape = (batch_size, input_channels, input_height, input_width)
    input_data = random_ops.random_normal(input_shape, dtype=dtype)
    # The following filter size results in nondeterminism being exercised in
    # cuDNN backprop (when determinism is not enabled) to both input and
    # filter, as well as in the specialized (non-cuDNN) depthwise backprop
    # to filter.
    filter_height = 7
    filter_width = 7
    channel_multiplier = 10
    filter_shape = (filter_height, filter_width, input_channels,
                    channel_multiplier)
    filter_data = random_ops.random_normal(filter_shape, dtype=dtype)
    strides = [1, 1, 1, 1]
    padding = "SAME"
    output_height = input_height  # because same padding
    output_width = input_width  # because same padding
    output_channels = input_channels * channel_multiplier
    if data_format == "NHWC":
      output_shape = (batch_size, output_height, output_width,
                      output_channels)
    else:  # "NCHW"
      output_shape = (batch_size, output_channels, output_height,
                      output_width)
    return input_data, filter_data, strides, padding, output_shape

  def _testForwardDeterminismCase(self,
                                  use_cudnn=False,
                                  data_format="NHWC",
                                  dtype=dtypes.float32):
    for seed in range(5):
      p = self._genParams(use_cudnn, data_format, dtype, seed=seed)
      input_data, filter_data, strides, padding, _ = p

      result_a = nn_impl.depthwise_conv2d_v2(input_data, filter_data, strides,
                                             padding, data_format)
      result_b = nn_impl.depthwise_conv2d_v2(input_data, filter_data, strides,
                                             padding, data_format)

      self.assertAllEqual(result_a, result_b)

  @test_util.run_gpu_only
  def testForwardDeterminismGPU(self):
    for use_cudnn in [False, True]:
      for data_format in ["NHWC", "NCHW"]:
        for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
          self._testForwardDeterminismCase(use_cudnn, data_format,
                                           dtype=dtype)

  def testForwardDeterminismCPU(self):
    if tf_config.list_physical_devices("GPU"):
      self.skipTest("Test only runs when there is no GPU")
    data_format = "NHWC"  # CPU does not implement the NCHW version of the op
    for dtype in [dtypes.bfloat16.as_numpy_dtype, dtypes.float32,
                  dtypes.float64]:
      self._testForwardDeterminismCase(data_format=data_format, dtype=dtype)

  def _testBackwardDeterminismCase(self,
                                   using_gpu=False,
                                   use_cudnn=False,
                                   data_format="NHWC",
                                   dtype=dtypes.float32):
    p = self._genParams(use_cudnn, data_format, dtype, seed=123)
    input_data, filter_data, strides, padding, output_shape = p

    def Gradients(upstream_gradients):
      # Multiplying the op output by known upstream gradients injects those
      # gradients into the backprop, so that both backprop-to-input and
      # backprop-to-filter are exercised with controlled inputs.
      with backprop.GradientTape() as tape:
        tape.watch(input_data)
        tape.watch(filter_data)
        op_output = nn_impl.depthwise_conv2d_v2(input_data, filter_data,
                                                strides, padding, data_format)
        gradient_injector_output = op_output * upstream_gradients
      return tape.gradient(gradient_injector_output,
                           [input_data, filter_data])

    # Test only two seeds, since testing takes a long time.
    for seed in (987, 988):
      upstream_gradients = random_ops.random_normal(
          output_shape, dtype=dtype, seed=seed)
      input_gradients_a, filter_gradients_a = Gradients(upstream_gradients)
      input_gradients_b, filter_gradients_b = Gradients(upstream_gradients)
      self.assertAllEqual(input_gradients_a, input_gradients_b)
      self.assertAllEqual(filter_gradients_a, filter_gradients_b)

  @test_util.run_gpu_only
  def testBackwardDeterminismGPU(self):
    using_gpu = True
    for use_cudnn in [False, True]:
      for data_format in ["NHWC", "NCHW"]:
        for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
          self._testBackwardDeterminismCase(using_gpu, use_cudnn, data_format,
                                            dtype)

  def testBackwardDeterminismCPU(self):
    if tf_config.list_physical_devices("GPU"):
      self.skipTest("Test only runs when there is no GPU")
    data_format = "NHWC"  # CPU does not implement the NCHW version of the op
    for dtype in [dtypes.bfloat16.as_numpy_dtype, dtypes.float32,
                  dtypes.float64]:
      self._testBackwardDeterminismCase(data_format=data_format, dtype=dtype)


if __name__ == "__main__":
  # The op-determinism setting can be enabled and disabled on the fly.
  # However, if cuDNN convolution is used (as it is for these tests), then
  # its setting at the time will influence which algorithm is cached for a
  # particular layer configuration (independently for XLA and non-XLA
  # operation).
  #
  # The tests in this file must be run under a separate test.main from the
  # tests in depthwise_conv_op_test.py, to prevent caching the selection of
  # nondeterministic algorithms, which would cause the tests defined in this
  # file to fail.
  #
  # Also because of this caching, the tests defined in
  # depthwise_conv_op_base.py should be run with and without op-determinism
  # enabled in separate files.
  #
  # TODO(duncanriach): Implement cuDNN auto-tuning cache invalidation and
  # execute it when the op-determinism setting is changed.
  tf_config.enable_op_determinism()
  test.main()
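# For reference, a minimal sketch (kept in comments so it does not execute at
# import time) of how the same determinism property would be checked from the
# public TensorFlow 2.x API, outside this internal test harness. The shapes
# mirror _genParams above and are illustrative only:
#
#   import tensorflow as tf
#
#   tf.config.experimental.enable_op_determinism()
#   x = tf.random.normal((2, 500, 1000, 2), seed=1)  # NHWC input
#   f = tf.random.normal((7, 7, 2, 10), seed=2)  # HWIO filter
#   y1 = tf.nn.depthwise_conv2d(x, f, strides=[1, 1, 1, 1], padding="SAME")
#   y2 = tf.nn.depthwise_conv2d(x, f, strides=[1, 1, 1, 1], padding="SAME")
#   assert (y1.numpy() == y2.numpy()).all()  # bitwise-identical outputs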