# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for fused batch norm operations."""

from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.tests import test_utils
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn
from tensorflow.python.platform import test

DATA_FORMATS = (
    ("_data_format_NHWC", "NHWC"),
    ("_data_format_NCHW", "NCHW"),
)

DATA_FORMATS_AND_AVG_FACTORS = (
    ("_data_format_NHWC_no_averaging", "NHWC", 1.0),
    ("_data_format_NHWC_averaging", "NHWC", 0.6),
    ("_data_format_NCHW_no_averaging", "NCHW", 1.0),
    ("_data_format_NCHW_averaging", "NCHW", 0.6),
)


class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):

  def _reference_training(self, x, scale, offset, old_mean, old_var, epsilon,
                          exponential_avg_factor, data_format):
    if data_format != "NHWC":
      raise ValueError("data_format must be NHWC, got %s." % data_format)
    x_square = x * x
    x_square_sum = np.sum(x_square, (0, 1, 2))
    x_sum = np.sum(x, axis=(0, 1, 2))
    element_count = np.size(x) / int(np.shape(x)[-1])
    mean = x_sum / element_count
    var = x_square_sum / element_count - mean * mean
    factor = element_count / max(element_count - 1, 1)
    corrected_var = var * factor
    normalized = (x - mean) / np.sqrt(var + epsilon)
    if exponential_avg_factor != 1.0:
      mean = ((1.0 - exponential_avg_factor) * old_mean +
              exponential_avg_factor * mean)
      corrected_var = ((1.0 - exponential_avg_factor) * old_var +
                       exponential_avg_factor * corrected_var)
    return (normalized * scale + offset), mean, var, corrected_var

  def _reference_grad(self, x, grad_y, scale, mean, var, epsilon, data_format):
    # Use the following formulas to calculate gradients:
    # grad_scale =
    #   sum(grad_y * (x - mean)) * rsqrt(var + epsilon)
    #
    # grad_offset = sum(grad_y)
    #
    # grad_x =
    #   1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) -
    #   (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
    if data_format != "NHWC":
      raise ValueError("data_format must be NHWC, got %s." % data_format)
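    # Note: np.mean over the (N, H, W) axes already folds in the 1/N factor
    # from the formula above, so the sums in grad_x below are written as means.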
    grad_x = scale * (grad_y - np.mean(grad_y, axis=(0, 1, 2)) -
                      (x - mean) * np.mean(grad_y * (x - mean), axis=(0, 1, 2)) /
                      (var + epsilon)) / np.sqrt(var + epsilon)
    grad_scale = np.sum(
        grad_y * (x - mean) / np.sqrt(var + epsilon), axis=(0, 1, 2))
    grad_offset = np.sum(grad_y, axis=(0, 1, 2))
    return grad_x, grad_scale, grad_offset

  @parameterized.named_parameters(*DATA_FORMATS)
  def testInference(self, data_format):
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001
    exponential_avg_factor = 1.0
    data_format_src = "NHWC"
    y_ref, mean_ref, var_ref, _ = self._reference_training(
        x_val, scale_val, offset_val, None, None, epsilon,
        exponential_avg_factor, data_format_src)

    with self.session() as sess, self.test_scope():
      # To avoid constant folding
      x_val_converted = test_utils.ConvertBetweenDataFormats(
          x_val, data_format_src, data_format)
      y_ref_converted = test_utils.ConvertBetweenDataFormats(
          y_ref, data_format_src, data_format)

      t_val = array_ops.placeholder(
          np.float32, shape=x_val_converted.shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      offset = array_ops.placeholder(
          np.float32, shape=scale_shape, name="offset")
      y, mean, variance = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=mean_ref,
          variance=var_ref,
          epsilon=epsilon,
          data_format=data_format,
          is_training=False)

      y_val, _, _ = sess.run([y, mean, variance], {
          t_val: x_val_converted,
          scale: scale_val,
          offset: offset_val
      })
      self.assertAllClose(y_val, y_ref_converted, atol=1e-3)

  def _testLearning(self, use_gradient_checker, data_format,
                    exponential_avg_factor):
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val_corr = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001
    data_format_src = "NHWC"
    # When in training mode, fused_batchnorm applies an implicit Bessel's
    # correction. So we have to use the corrected variance here, as well.
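    # The reference below also folds in the exponential moving average update,
    # so mean_ref and var_ref_corr can be compared against the mean and
    # variance outputs of the op.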
    y_ref, mean_ref, _, var_ref_corr = self._reference_training(
        x_val, scale_val, offset_val, mean_val, var_val_corr, epsilon,
        exponential_avg_factor, data_format_src)

    with self.session() as sess, self.test_scope():
      # To avoid constant folding
      x_val_converted = test_utils.ConvertBetweenDataFormats(
          x_val, data_format_src, data_format)
      y_ref_converted = test_utils.ConvertBetweenDataFormats(
          y_ref, data_format_src, data_format)

      t_val = array_ops.placeholder(
          np.float32, shape=x_val_converted.shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      offset = array_ops.placeholder(
          np.float32, shape=scale_shape, name="offset")
      if exponential_avg_factor == 1.0:
        old_mean = None
        old_var = None
      else:
        old_mean = array_ops.placeholder(
            np.float32, shape=scale_shape, name="old_mean")
        old_var = array_ops.placeholder(
            np.float32, shape=scale_shape, name="old_var")
      y, mean, var = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=old_mean,
          variance=old_var,
          epsilon=epsilon,
          exponential_avg_factor=exponential_avg_factor,
          data_format=data_format,
          is_training=True)
      if exponential_avg_factor == 1.0:
        feed_dict = {
            t_val: x_val_converted,
            scale: scale_val,
            offset: offset_val,
        }
      else:
        feed_dict = {
            t_val: x_val_converted,
            scale: scale_val,
            offset: offset_val,
            old_mean: mean_val,
            old_var: var_val_corr
        }
      # Check gradient.
      if use_gradient_checker:
        err = gradient_checker.compute_gradient_error(
            t_val,
            x_val_converted.shape,
            y,
            x_val_converted.shape,
            extra_feed_dict=feed_dict)
        self.assertLess(err, 1e-3)

      y_tf, mean_tf, var_tf = sess.run([y, mean, var], feed_dict)
      self.assertAllClose(y_tf, y_ref_converted, atol=1e-3)
      self.assertAllClose(mean_tf, mean_ref, atol=1e-3)
      self.assertAllClose(var_tf, var_ref_corr, atol=1e-3)

  @parameterized.named_parameters(*DATA_FORMATS_AND_AVG_FACTORS)
  def testLearning(self, data_format, exponential_avg_factor):
    self._testLearning(False, data_format, exponential_avg_factor)

  @parameterized.named_parameters(*DATA_FORMATS_AND_AVG_FACTORS)
  def testLearningWithGradientChecker(self, data_format,
                                      exponential_avg_factor):
    self._testLearning(True, data_format, exponential_avg_factor)

  @parameterized.named_parameters(*DATA_FORMATS)
  def testGradientTraining(self, data_format):
    # Skip on GPU when the MLIR bridge is enabled, as there is no
    # FusedBatchNorm legalization for GPU in MLIR.
    # TODO(b/189039456): Customize FusedBatchNorm legalization for GPU in MLIR.
    if test_util.is_mlir_bridge_enabled() and self.device == "XLA_GPU":
      self.skipTest("b/189039456")

    # TODO(b/64270657): Use gradient_checker here in addition to comparing with
    # this reference implementation.
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    grad_val = np.random.random_sample(x_shape).astype(np.float32)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001

    # The TensorFlow FusedBatchNormGrad training operation takes two inputs
    # with implementation-defined values. In theory the only correct values for
    # these inputs are the corresponding reserve_space_{1|2} outputs from the
    # FusedBatchNorm training operation. However, in practice, we rely on the
    # first one being the mean on both CPU and GPU, and the second one being
    # the variance on CPU and inverse(sqrt(variance + epsilon)) on GPU (we test
    # this assumption separately).
    reserve_space_1_val = mean_val
    if self.device == "XLA_GPU":
      reserve_space_2_val = np.reciprocal(np.sqrt(var_val + epsilon))
    else:
      reserve_space_2_val = var_val

    data_format_src = "NHWC"
    grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad(
        x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src)

    with self.session() as sess, self.test_scope():
      grad_val_converted = test_utils.ConvertBetweenDataFormats(
          grad_val, data_format_src, data_format)
      x_val_converted = test_utils.ConvertBetweenDataFormats(
          x_val, data_format_src, data_format)
      grad_x_ref_converted = test_utils.ConvertBetweenDataFormats(
          grad_x_ref, data_format_src, data_format)

      grad = array_ops.placeholder(
          np.float32, shape=x_val_converted.shape, name="grad")
      x = array_ops.placeholder(
          np.float32, shape=x_val_converted.shape, name="x")
      reserve_space_1 = array_ops.placeholder(
          np.float32, shape=scale_shape, name="reserve_space_1")
      reserve_space_2 = array_ops.placeholder(
          np.float32, shape=scale_shape, name="reserve_space_2")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
          grad,
          x,
          scale,
          reserve_space_1,
          reserve_space_2,
          data_format=data_format,
          is_training=True)

      grad_x_val, grad_scale_val, grad_offset_val = sess.run(
          [grad_x, grad_scale, grad_offset], {
              grad: grad_val_converted,
              x: x_val_converted,
              reserve_space_1: reserve_space_1_val,
              reserve_space_2: reserve_space_2_val,
              scale: scale_val
          })

      self.assertAllClose(grad_x_val, grad_x_ref_converted, atol=1e-2)
      self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
      self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)

  @parameterized.named_parameters(*DATA_FORMATS)
  def testGradientInference(self, data_format):
    # TODO(b/64270657): Use gradient_checker here in addition to comparing with
    # this reference implementation.
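    # Unlike testGradientTraining above, there is no numpy reference here: the
    # same FusedBatchNormGrad op is built twice and the two runs are compared
    # against each other.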
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    grad_val = np.random.random_sample(x_shape).astype(np.float32)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val = np.random.random_sample(scale_shape).astype(np.float32)
    data_format_src = "NHWC"

    with self.session() as sess, self.test_scope():
      grad_val_converted = test_utils.ConvertBetweenDataFormats(
          grad_val, data_format_src, data_format)
      x_val_converted = test_utils.ConvertBetweenDataFormats(
          x_val, data_format_src, data_format)

      grad = array_ops.placeholder(
          np.float32, shape=x_val_converted.shape, name="grad")
      x = array_ops.placeholder(
          np.float32, shape=x_val_converted.shape, name="x")
      mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean")
      var = array_ops.placeholder(np.float32, shape=scale_shape, name="var")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      with self.test_scope():
        out = gen_nn_ops.fused_batch_norm_grad(
            grad,
            x,
            scale,
            mean,
            var,
            data_format=data_format,
            is_training=False)
        grad_x, grad_scale, grad_offset, _, _ = out

      ref_x, ref_scale, ref_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
          grad, x, scale, mean, var, data_format=data_format, is_training=False)

      grad_x_val, grad_scale_val, grad_offset_val = sess.run(
          [grad_x, grad_scale, grad_offset], {
              grad: grad_val_converted,
              x: x_val_converted,
              mean: mean_val,
              var: var_val,
              scale: scale_val
          })
      grad_x_ref, grad_scale_ref, grad_offset_ref = sess.run(
          [ref_x, ref_scale, ref_offset], {
              grad: grad_val_converted,
              x: x_val_converted,
              mean: mean_val,
              var: var_val,
              scale: scale_val
          })

      self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2)
      self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
      self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)


if __name__ == "__main__":
  test.main()