# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for fused batch norm operations."""

from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.tests import test_utils
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn
from tensorflow.python.platform import test

DATA_FORMATS = (
    ("_data_format_NHWC", "NHWC"),
    ("_data_format_NCHW", "NCHW"),
)

DATA_FORMATS_AND_AVG_FACTORS = (
    ("_data_format_NHWC_no_averaging", "NHWC", 1.0),
    ("_data_format_NHWC_averaging", "NHWC", 0.6),
    ("_data_format_NCHW_no_averaging", "NCHW", 1.0),
    ("_data_format_NCHW_averaging", "NCHW", 0.6),
)


class FusedBatchNormTest(xla_test.XLATestCase, parameterized.TestCase):

  def _reference_training(self, x, scale, offset, old_mean, old_var, epsilon,
                          exponential_avg_factor, data_format):
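    """NumPy reference for FusedBatchNorm in training mode (NHWC only).

    Returns the batch-normalized output, the batch mean, the biased batch
    variance, and the Bessel-corrected variance. When exponential_avg_factor
    != 1.0, the returned mean and corrected variance are blended with
    old_mean and old_var.
    """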
    if data_format != "NHWC":
      raise ValueError("data_format must be NHWC, got %s." % data_format)
    x_square = x * x
    x_square_sum = np.sum(x_square, (0, 1, 2))
    x_sum = np.sum(x, axis=(0, 1, 2))
    element_count = np.size(x) / int(np.shape(x)[-1])
    mean = x_sum / element_count
    var = x_square_sum / element_count - mean * mean
    factor = element_count / max(element_count - 1, 1)
    corrected_var = var * factor
    normalized = (x - mean) / np.sqrt(var + epsilon)
    if exponential_avg_factor != 1.0:
      mean = (1.0 -
              exponential_avg_factor) * old_mean + exponential_avg_factor * mean
      corrected_var = (1.0 - exponential_avg_factor
                      ) * old_var + exponential_avg_factor * corrected_var
    return (normalized * scale + offset), mean, var, corrected_var

  def _reference_grad(self, x, grad_y, scale, mean, var, epsilon, data_format):
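    """NumPy reference for FusedBatchNormGrad in training mode (NHWC only).

    Returns (grad_x, grad_scale, grad_offset) computed from the batch
    statistics, following the formulas in the comments below.
    """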
    # Use the following formulas to calculate gradients:
    # grad_scale =
    #   sum(grad_y * (x - mean)) * rsqrt(var + epsilon)
    #
    # grad_offset = sum(grad_y)
    #
    # grad_x =
    #   1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) -
    #   (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
    if data_format != "NHWC":
      raise ValueError("data_format must be NHWC, got %s." % data_format)
    grad_x = scale * (grad_y - np.mean(grad_y, axis=(0, 1, 2)) -
                      (x - mean) * np.mean(grad_y *
                                           (x - mean), axis=(0, 1, 2)) /
                      (var + epsilon)) / np.sqrt(var + epsilon)
    grad_scale = np.sum(
        grad_y * (x - mean) / np.sqrt(var + epsilon), axis=(0, 1, 2))
    grad_offset = np.sum(grad_y, axis=(0, 1, 2))
    return grad_x, grad_scale, grad_offset

  @parameterized.named_parameters(*DATA_FORMATS)
  def testInference(self, data_format):
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001
    exponential_avg_factor = 1.0
    data_format_src = "NHWC"
    y_ref, mean_ref, var_ref, _ = self._reference_training(
        x_val, scale_val, offset_val, None, None, epsilon,
        exponential_avg_factor, data_format_src)

    with self.session() as sess, self.test_scope():
      # To avoid constant folding
      x_val_converted = test_utils.ConvertBetweenDataFormats(
          x_val, data_format_src, data_format)
      y_ref_converted = test_utils.ConvertBetweenDataFormats(
          y_ref, data_format_src, data_format)

      t_val = array_ops.placeholder(
          np.float32, shape=x_val_converted.shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      offset = array_ops.placeholder(
          np.float32, shape=scale_shape, name="offset")
      y, mean, variance = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=mean_ref,
          variance=var_ref,
          epsilon=epsilon,
          data_format=data_format,
          is_training=False)

      y_val, _, _ = sess.run([y, mean, variance], {
          t_val: x_val_converted,
          scale: scale_val,
          offset: offset_val
      })
      self.assertAllClose(y_val, y_ref_converted, atol=1e-3)

  def _testLearning(self, use_gradient_checker, data_format,
                    exponential_avg_factor):
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val_corr = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001
    data_format_src = "NHWC"
    # When in training mode, fused_batchnorm applies an implicit Bessel's
    # correction. So we have to use the corrected variance here, as well.
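    # (See the factor = element_count / max(element_count - 1, 1) scaling in
    # _reference_training above.)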
    y_ref, mean_ref, _, var_ref_corr = self._reference_training(
        x_val, scale_val, offset_val, mean_val, var_val_corr, epsilon,
        exponential_avg_factor, data_format_src)

    with self.session() as sess, self.test_scope():
      # To avoid constant folding
      x_val_converted = test_utils.ConvertBetweenDataFormats(
          x_val, data_format_src, data_format)
      y_ref_converted = test_utils.ConvertBetweenDataFormats(
          y_ref, data_format_src, data_format)

      t_val = array_ops.placeholder(
          np.float32, shape=x_val_converted.shape, name="x")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      offset = array_ops.placeholder(
          np.float32, shape=scale_shape, name="offset")
      if exponential_avg_factor == 1.0:
        old_mean = None
        old_var = None
      else:
        old_mean = array_ops.placeholder(
            np.float32, shape=scale_shape, name="old_mean")
        old_var = array_ops.placeholder(
            np.float32, shape=scale_shape, name="old_var")
      y, mean, var = nn.fused_batch_norm(
          t_val,
          scale,
          offset,
          mean=old_mean,
          variance=old_var,
          epsilon=epsilon,
          exponential_avg_factor=exponential_avg_factor,
          data_format=data_format,
          is_training=True)
      if exponential_avg_factor == 1.0:
        feed_dict = {
            t_val: x_val_converted,
            scale: scale_val,
            offset: offset_val,
        }
      else:
        feed_dict = {
            t_val: x_val_converted,
            scale: scale_val,
            offset: offset_val,
            old_mean: mean_val,
            old_var: var_val_corr
        }
      # Check gradient.
      if use_gradient_checker:
        err = gradient_checker.compute_gradient_error(
            t_val,
            x_val_converted.shape,
            y,
            x_val_converted.shape,
            extra_feed_dict=feed_dict)
        self.assertLess(err, 1e-3)

      y_tf, mean_tf, var_tf = sess.run([y, mean, var], feed_dict)
      self.assertAllClose(y_tf, y_ref_converted, atol=1e-3)
      self.assertAllClose(mean_tf, mean_ref, atol=1e-3)
      self.assertAllClose(var_tf, var_ref_corr, atol=1e-3)

  @parameterized.named_parameters(*DATA_FORMATS_AND_AVG_FACTORS)
  def testLearning(self, data_format, exponential_avg_factor):
    self._testLearning(False, data_format, exponential_avg_factor)

  @parameterized.named_parameters(*DATA_FORMATS_AND_AVG_FACTORS)
  def testLearningWithGradientChecker(self, data_format,
                                      exponential_avg_factor):
    self._testLearning(True, data_format, exponential_avg_factor)

  @parameterized.named_parameters(*DATA_FORMATS)
  def testGradientTraining(self, data_format):
    # Skip when the MLIR bridge is enabled on GPU, as there is no
    # FusedBatchNorm legalization for GPU with MLIR.
    # TODO(b/189039456): Customize FusedBatchNorm legalization for GPU in MLIR.
    if test_util.is_mlir_bridge_enabled() and self.device == "XLA_GPU":
      self.skipTest("b/189039456")

    # TODO(b/64270657): Use gradient_checker here in addition to comparing with
    # this reference implementation.
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    grad_val = np.random.random_sample(x_shape).astype(np.float32)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001

    # The TensorFlow FusedBatchNormGrad training operation takes two inputs
    # with implementation-defined values.  In theory the only correct values
    # for these inputs are the corresponding reserve_space_{1|2} outputs from
    # the FusedBatchNorm training operation.  However, in practice, we rely on
    # the first one being the mean on {C|G}PU, and the second one being the
    # variance on CPU and inverse(sqrt(variance + epsilon)) on GPU (we test
    # this assumption separately).
    reserve_space_1_val = mean_val
    if self.device == "XLA_GPU":
      reserve_space_2_val = np.reciprocal(np.sqrt(var_val + epsilon))
    else:
      reserve_space_2_val = var_val

    data_format_src = "NHWC"
    grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad(
        x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src)

    with self.session() as sess, self.test_scope():
      grad_val_converted = test_utils.ConvertBetweenDataFormats(
          grad_val, data_format_src, data_format)
      x_val_converted = test_utils.ConvertBetweenDataFormats(
          x_val, data_format_src, data_format)
      grad_x_ref_converted = test_utils.ConvertBetweenDataFormats(
          grad_x_ref, data_format_src, data_format)

      grad = array_ops.placeholder(
          np.float32, shape=x_val_converted.shape, name="grad")
      x = array_ops.placeholder(
          np.float32, shape=x_val_converted.shape, name="x")
      reserve_space_1 = array_ops.placeholder(
          np.float32, shape=scale_shape, name="reserve_space_1")
      reserve_space_2 = array_ops.placeholder(
          np.float32, shape=scale_shape, name="reserve_space_2")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
          grad,
          x,
          scale,
          reserve_space_1,
          reserve_space_2,
          data_format=data_format,
          is_training=True)

      grad_x_val, grad_scale_val, grad_offset_val = sess.run(
          [grad_x, grad_scale, grad_offset], {
              grad: grad_val_converted,
              x: x_val_converted,
              reserve_space_1: reserve_space_1_val,
              reserve_space_2: reserve_space_2_val,
              scale: scale_val
          })

      self.assertAllClose(grad_x_val, grad_x_ref_converted, atol=1e-2)
      self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
      self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)

  @parameterized.named_parameters(*DATA_FORMATS)
  def testGradientInference(self, data_format):
    # TODO(b/64270657): Use gradient_checker here in addition to comparing with
    # this reference implementation.
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    grad_val = np.random.random_sample(x_shape).astype(np.float32)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    mean_val = np.random.random_sample(scale_shape).astype(np.float32)
    var_val = np.random.random_sample(scale_shape).astype(np.float32)
    data_format_src = "NHWC"

    with self.session() as sess, self.test_scope():
      grad_val_converted = test_utils.ConvertBetweenDataFormats(
          grad_val, data_format_src, data_format)
      x_val_converted = test_utils.ConvertBetweenDataFormats(
          x_val, data_format_src, data_format)

      grad = array_ops.placeholder(
          np.float32, shape=x_val_converted.shape, name="grad")
      x = array_ops.placeholder(
          np.float32, shape=x_val_converted.shape, name="x")
      mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean")
      var = array_ops.placeholder(np.float32, shape=scale_shape, name="var")
      scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
      with self.test_scope():
        out = gen_nn_ops.fused_batch_norm_grad(
            grad,
            x,
            scale,
            mean,
            var,
            data_format=data_format,
            is_training=False)
        grad_x, grad_scale, grad_offset, _, _ = out

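      # Build the reference outputs from a second instance of the same op.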
330          grad, x, scale, mean, var, data_format=data_format, is_training=False)
331
332      grad_x_val, grad_scale_val, grad_offset_val, = sess.run(
333          [grad_x, grad_scale, grad_offset], {
334              grad: grad_val_converted,
335              x: x_val_converted,
336              mean: mean_val,
337              var: var_val,
338              scale: scale_val
339          })
340      grad_x_ref, grad_scale_ref, grad_offset_ref, = sess.run(
341          [ref_x, ref_scale, ref_offset], {
342              grad: grad_val_converted,
343              x: x_val_converted,
344              mean: mean_val,
345              var: var_val,
346              scale: scale_val
347          })
348
349      self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2)
350      self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2)
351      self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3)
352
353
354if __name__ == "__main__":
355  test.main()
356