xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tests/stateless_random_ops_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for stateless random-number generation ops."""
16
17import functools
18from absl.testing import parameterized
19import numpy as np
20
21from tensorflow.compiler.tests import xla_test
22from tensorflow.python.eager import def_function
23from tensorflow.python.framework import config
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import test_util
27from tensorflow.python.kernel_tests.random import util as \
28random_test_util
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import gen_stateless_random_ops_v2
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import stateless_random_ops as stateless
33from tensorflow.python.ops import variables
34from tensorflow.python.platform import test
35
36
class StatelessRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
  """Test cases for stateless random-number generator operators."""

  def _random_types(self, include_int=False):
    """Returns the dtypes to exercise, restricted to `self.all_tf_types`.

    Args:
      include_int: If True, `int32` and `int64` are candidates in addition to
        `float64`, `float32` and `bfloat16`.

    Returns:
      The intersection of `self.all_tf_types` with the candidate dtypes. The
      result is an unordered collection, so callers must not depend on the
      order in which dtypes are visited.
    """
    allowed_types = {dtypes.float64, dtypes.float32, dtypes.bfloat16}
    if include_int:
      allowed_types.update({dtypes.int32, dtypes.int64})
    return self.all_tf_types & allowed_types

  @test_util.run_v2_only
  def testForcedCompile(self):
    """Tests whole-function forced-compilation.

    This test checks that stateless_random_* can be used in forced-compilation
    scenarios (e.g. TPU). The new version of stateless_random_* requires the
    intermediate tensor `alg` to be compile-time constant, so we need to check
    that this requirement won't prevent `seed` from depending on variables.
    """
    if config.list_logical_devices('TPU'):
      self.skipTest('To accommodate OSS, experimental_compile support for TPU '
                    'is not linked in.')
    # GPU doesn't support int32 variables, so we use int64.
    v = variables.Variable([1, 2], dtype=dtypes.int64)

    # `jit_compile` supersedes the deprecated `experimental_compile` argument
    # and matches the spelling used by `testAlg` below.
    @def_function.function(jit_compile=True)
    def f():
      key, counter = (
          gen_stateless_random_ops_v2.stateless_random_get_key_counter(
              seed=math_ops.cast(v.read_value(), dtypes.int32)))
      alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
      return gen_stateless_random_ops_v2.stateless_random_normal_v2(
          shape=[], key=key, counter=counter, alg=alg)

    f()

  @test_util.run_v2_only
  def testGetKeyCounterAlg(self):
    """Checks the output shapes of the key/counter and alg helper ops."""
    seed = [1, 2]
    key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
        seed)
    self.assertAllEqual(key.shape, [1])
    self.assertAllEqual(counter.shape, [2])
    alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
    self.assertAllEqual(alg.shape, [])

  @parameterized.named_parameters(
      ('_%s_%s' % (op_id, alg_id), op, alg_group)  # pylint: disable=g-complex-comprehension
      for alg_id, alg_group in enumerate([
          [
              stateless.Algorithm.PHILOX, stateless.Algorithm.PHILOX.value,
              'philox'
          ],
          [
              stateless.Algorithm.THREEFRY, stateless.Algorithm.THREEFRY.value,
              'threefry'
          ],
          [
              stateless.Algorithm.AUTO_SELECT,
              stateless.Algorithm.AUTO_SELECT.value, 'auto_select', None
          ],
      ])
      for op_id, op in enumerate([
          stateless.stateless_random_normal,
          stateless.stateless_truncated_normal,
          functools.partial(
              stateless.stateless_random_uniform,
              dtype=dtypes.uint32,
              minval=None,
              maxval=None),
          functools.partial(
              stateless.stateless_random_uniform,
              dtype=dtypes.int32,
              maxval=100),
          functools.partial(
              stateless.stateless_random_uniform, dtype=dtypes.float32),
      ]))
  @test_util.run_v2_only
  def testAlg(self, op, alg_group):
    """Tests all values of `alg`.

    Every spelling of an algorithm in `alg_group` (enum member, its integer
    value, its string name, and `None` for the AUTO_SELECT group) must produce
    identical outputs for the same seed and shape.
    """
    if config.list_logical_devices('TPU') or config.list_logical_devices('GPU'):
      self.skipTest('Only _cpu tests linked in support for jit_compile on CPU.')
    seed = [1, 2]
    shape = [2, 3]
    outputs = []
    for alg in alg_group:
      with ops.device('CPU'):
        output = def_function.function(jit_compile=True)(op)(
            shape=shape, seed=seed, alg=alg)
      self.assertEqual(output.shape, shape)
      outputs.append(output)
    x = outputs[0]
    for y in outputs[1:]:
      self.assertAllEqual(x, y)

  def testLargeNormal(self):
    """Tests an OOM bug of StatelessRandomNormalV2 on TPU."""
    with self.session() as sess, self.test_scope():
      seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
      # Variant 1: key, counter and alg from the fused helper op.
      key, counter, alg = (gen_stateless_random_ops_v2.
                           stateless_random_get_key_counter_alg(seed_t))
      x = gen_stateless_random_ops_v2.stateless_random_normal_v2(
          shape=[1024, 32000], key=key, counter=counter, dtype=dtypes.float32,
          alg=alg)
      y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
      self.assertAllEqual([1024, 32000], y.shape)
      # Variant 2: key/counter and alg from separate helper ops.
      key, counter = (gen_stateless_random_ops_v2.
                      stateless_random_get_key_counter(seed_t))
      alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
      x = gen_stateless_random_ops_v2.stateless_random_normal_v2(
          shape=[1024, 32000], key=key, counter=counter, dtype=dtypes.float32,
          alg=alg)
      y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
      self.assertAllEqual([1024, 32000], y.shape)

  def testDeterminism(self):
    """Checks that outputs are equal exactly when the seeds are equal."""
    # Stateless values should be equal iff the seeds are equal (roughly)
    with self.session(), self.test_scope():
      seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
      seeds = [(x, y) for x in range(-2, 3) for y in range(-2, 3)] * 3  # pylint: disable=g-complex-comprehension
      for stateless_op in [
          stateless.stateless_random_uniform, stateless.stateless_random_normal
      ]:
        for shape in (), (3,), (2, 5):
          for dtype in self._random_types():
            # Skip bfloat16. The result of bfloat16 is truncated from 32-bit
            # result. With different seeds, the 32-bit results are different,
            # but the truncated 16-bit results might be the same.
            if dtype == dtypes.bfloat16:
              continue
            pure = stateless_op(shape, seed=seed_t, dtype=dtype)
            values = [(seed, pure.eval(feed_dict={
                seed_t: seed
            })) for seed in seeds]
            for s0, v0 in values:
              for s1, v1 in values:
                self.assertEqual(s0 == s1, np.all(v0 == v1))

  def testRandomUniformIsInRange(self):
    """Checks uniform samples lie in [0, maxval) for every tested dtype."""
    with self.session() as sess, self.test_scope():
      for dtype in self._random_types(include_int=True):
        maxval = 1
        if dtype.is_integer:
          maxval = 100
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        x = stateless.stateless_random_uniform(
            shape=[1000], seed=seed_t, maxval=maxval, dtype=dtype)
        y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
        self.assertTrue(np.all(y >= 0))
        self.assertTrue(np.all(y < maxval))

  def testDistributionOfStatelessRandomUniform(self):
    """Use Pearson's Chi-squared test to test for uniformity."""
    with self.session() as sess, self.test_scope():
      for dtype in self._random_types(include_int=True):
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        n = 1000
        maxval = 1
        if dtype.is_integer:
          maxval = 100
        x = stateless.stateless_random_uniform(
            shape=[n], seed=seed_t, maxval=maxval, dtype=dtype)
        y = sess.run(x, {seed_t: [565656, 121212]})
        # Convert y to float and normalize its value to range [0, 1) when
        # maxval != 1.
        y = y.astype(float) / maxval
        # Tests that the values are distributed amongst 10 bins with equal
        # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
        # p=0.05. This test is probabilistic and would be flaky if the random
        # seed were not fixed.
        self.assertLess(random_test_util.chi_squared(y, 10), 16.92)

  def testRandomNormalIsFinite(self):
    """Checks normal samples contain no inf or NaN values."""
    with self.session() as sess, self.test_scope():
      for dtype in self._random_types():
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        x = stateless.stateless_random_normal(
            shape=[10000], seed=seed_t, dtype=dtype)
        y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
        self.assertTrue(np.all(np.isfinite(y)))

  def testDistributionOfStatelessRandomNormal(self):
    """Use Anderson-Darling test to test distribution appears normal."""
    with self.session() as sess, self.test_scope():
      for dtype in self._random_types():
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        n = 1000
        x = stateless.stateless_random_normal(
            shape=[n], seed=seed_t, dtype=dtype)
        y = sess.run(x, {seed_t: [25252, 314159]})
        # The constant 2.492 is the 5% critical value for the Anderson-Darling
        # test where the mean and variance are known. This test is probabilistic
        # so to avoid flakiness the seed is fixed.
        self.assertLess(
            random_test_util.anderson_darling(y.astype(float)), 2.492)

  def testTruncatedNormal(self):
    """Checks moments of stateless_truncated_normal samples per dtype."""
    for dtype in self._random_types():
      with self.session() as sess, self.test_scope():
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        n = 10000000
        x = stateless.stateless_truncated_normal(
            shape=[n], seed=seed_t, dtype=dtype)
        y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
        random_test_util.test_truncated_normal(
            self.assertEqual, self.assertAllClose, n, y,
            variance_rtol=6e-3 if dtype == dtypes.bfloat16 else 1e-3)

  def _testParameterizedTruncatedNormal(self,
                                        means,
                                        stddevs,
                                        minvals,
                                        maxvals,
                                        variance_rtol=None):
    """Checks moments of stateless_parameterized_truncated_normal samples.

    Args:
      means: Mean parameter forwarded to the op and to the moment checks.
      stddevs: Standard-deviation parameter forwarded likewise.
      minvals: Lower truncation bound forwarded likewise.
      maxvals: Upper truncation bound forwarded likewise.
      variance_rtol: Relative tolerance for the variance check. If None, a
        per-dtype default is used: 6e-3 for bfloat16, 1e-3 otherwise.
    """
    for dtype in self._random_types():
      # Resolve the tolerance per iteration instead of overwriting
      # `variance_rtol`. The dtypes come from an unordered collection, so
      # overwriting would let whichever dtype is visited first fix the
      # tolerance for every later dtype.
      if variance_rtol is None:
        rtol = 6e-3 if dtype == dtypes.bfloat16 else 1e-3
      else:
        rtol = variance_rtol
      with self.session() as sess, self.test_scope():
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        n = int(10e7)
        # NOTE(review): `dtype` is not passed to the op below, so the sample
        # dtype presumably follows `means` (float32 for Python floats) on
        # every iteration. Confirm whether dtype-typed parameters should be
        # fed here.
        x = stateless.stateless_parameterized_truncated_normal(
            shape=[n],
            seed=seed_t,
            means=means,
            stddevs=stddevs,
            minvals=minvals,
            maxvals=maxvals)
        y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
        random_test_util.test_truncated_normal(
            self.assertEqual,
            self.assertAllClose,
            n,
            y,
            means=means,
            stddevs=stddevs,
            minvals=minvals,
            maxvals=maxvals,
            mean_atol=1e-3,
            median_atol=1e-3,
            variance_rtol=rtol)

  def testParameterizedTruncatedNormalDefault(self):
    self._testParameterizedTruncatedNormal(0., 1., -2., 2.)

  def testParameterizedTruncatedNormalShifted(self):
    self._testParameterizedTruncatedNormal(-1., 1., -2., 2.)

  def testParameterizedTruncatedNormalRightTail(self):
    self._testParameterizedTruncatedNormal(0., 1., 4., 20., variance_rtol=2e-2)

  def testParameterizedTruncatedNormalLeftTail(self):
    self._testParameterizedTruncatedNormal(
        0., 1., -20., -4., variance_rtol=5e-2)

  def testParameterizedTruncatedNormalLeftTailTwoSidedBounds(self):
    self._testParameterizedTruncatedNormal(
        0., 1., -6., -3., variance_rtol=5e-2)

  def testParameterizedTruncatedNormalSmallStddev(self):
    self._testParameterizedTruncatedNormal(0., 0.1, 0.05, 0.10)

  def testParameterizedTruncatedNormalBroadcast(self):
    """Checks the parameters broadcast against the requested output shape."""
    with self.session() as sess, self.test_scope():
      seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
      means = array_ops.zeros([2], dtype=dtypes.float32)
      stddevs = array_ops.ones([3, 1], dtype=dtypes.float32)
      minvals = -array_ops.ones([5, 1, 1], dtype=dtypes.float32)
      maxvals = array_ops.ones([7, 1, 1, 1], dtype=dtypes.float32)
      shape = [11, 7, 5, 3, 2]
      x = stateless.stateless_parameterized_truncated_normal(
          shape=shape,
          seed=seed_t,
          means=means,
          stddevs=stddevs,
          minvals=minvals,
          maxvals=maxvals)
      y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
      self.assertEqual((11, 7, 5, 3, 2), y.shape)
314
315
class StatelessRandomOpsBenchmark(test.Benchmark):
  """Microbenchmarks of stateless uniform sampling on CPU, with/without XLA."""

  def _benchmarkUniform(self, name, dtype, use_xla_jit):
    """Benchmarks `stateless_random_uniform` on the CPU device.

    Args:
      name: Prefix of the reported benchmark name.
      dtype: Dtype of the sampled tensor.
      use_xla_jit: Forwarded to `xla_test.Benchmark`; presumably toggles XLA
        JIT compilation of the benchmarked graph.
    """
    sample_shape = (10, 1000, 1000)

    def build_graph():
      # The seed is held in a variable named 'input' rather than a literal.
      seed = variables.Variable((312, 456), dtype=dtypes.int32, name='input')
      samples = stateless.stateless_random_uniform(
          sample_shape, seed=seed, dtype=dtype)
      return f'{name}.shape{sample_shape}', [samples]

    xla_test.Benchmark(
        self, build_graph, use_xla_jit=use_xla_jit, device='cpu')

  def benchmarkUniformF32(self):
    self._benchmarkUniform('uniform_f32', dtypes.float32, False)

  def benchmarkUniformF64(self):
    self._benchmarkUniform('uniform_f64', dtypes.float64, False)

  def benchmarkUniformF32XLA(self):
    self._benchmarkUniform('uniform_f32', dtypes.float32, True)

  def benchmarkUniformF64XLA(self):
    self._benchmarkUniform('uniform_f64', dtypes.float64, True)
347
348
if __name__ == '__main__':
  # With soft placement disabled, ops stay on the device they were explicitly
  # assigned instead of being moved to another device.
  config.set_soft_device_placement(enabled=False)
  test.main()
352