# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for stateless random-number generation ops."""

import functools
from absl.testing import parameterized
import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.kernel_tests.random import util as \
random_test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_stateless_random_ops_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops as stateless
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class StatelessRandomOpsTest(xla_test.XLATestCase, parameterized.TestCase):
  """Test cases for stateless random-number generator operators."""

  def _random_types(self, include_int=False):
    """Returns the dtypes to exercise, as a set.

    Args:
      include_int: If True, int32 and int64 are candidates in addition to
        float64, float32 and bfloat16.

    Returns:
      The candidate dtypes intersected with `self.all_tf_types` (presumably
      the dtypes available on the device under test). Being a set, its
      iteration order is not something callers should rely on.
    """
    allowed_types = {dtypes.float64, dtypes.float32, dtypes.bfloat16}
    if include_int:
      allowed_types.update({dtypes.int32, dtypes.int64})
    return self.all_tf_types & allowed_types

  @test_util.run_v2_only
  def testForcedCompile(self):
    """Tests whole-function forced-compilation.

    This test checks that stateless_random_* can be used in forced-compilation
    scenarios (e.g. TPU). The new version of stateless_random_* requires the
    intermediate tensor `alg` to be compile-time constant, so we need to check
    that this requirement won't prevent `seed` from depending on variables.
    """
    if config.list_logical_devices('TPU'):
      self.skipTest('To accommodate OSS, jit_compile support for TPU '
                    'is not linked in.')
    # GPU doesn't support int32 variables, so we use int64.
    v = variables.Variable([1, 2], dtype=dtypes.int64)

    # `jit_compile` replaces the deprecated `experimental_compile` argument;
    # testAlg below already uses it.
    @def_function.function(jit_compile=True)
    def f():
      key, counter = (
          gen_stateless_random_ops_v2.stateless_random_get_key_counter(
              seed=math_ops.cast(v.read_value(), dtypes.int32)))
      alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
      return gen_stateless_random_ops_v2.stateless_random_normal_v2(
          shape=[], key=key, counter=counter, alg=alg)

    f()

  @test_util.run_v2_only
  def testGetKeyCounterAlg(self):
    """Checks the output shapes of the key/counter and alg helper ops."""
    seed = [1, 2]
    key, counter = gen_stateless_random_ops_v2.stateless_random_get_key_counter(
        seed)
    self.assertAllEqual(key.shape, [1])
    self.assertAllEqual(counter.shape, [2])
    alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
    self.assertAllEqual(alg.shape, [])

  @parameterized.named_parameters(
      ('_%s_%s' % (op_id, alg_id), op, alg_group)  # pylint: disable=g-complex-comprehension
      for alg_id, alg_group in enumerate([
          [
              stateless.Algorithm.PHILOX, stateless.Algorithm.PHILOX.value,
              'philox'
          ],
          [
              stateless.Algorithm.THREEFRY, stateless.Algorithm.THREEFRY.value,
              'threefry'
          ],
          [
              stateless.Algorithm.AUTO_SELECT,
              stateless.Algorithm.AUTO_SELECT.value, 'auto_select', None
          ],
      ])
      for op_id, op in enumerate([
          stateless.stateless_random_normal,
          stateless.stateless_truncated_normal,
          functools.partial(
              stateless.stateless_random_uniform,
              dtype=dtypes.uint32,
              minval=None,
              maxval=None),
          functools.partial(
              stateless.stateless_random_uniform,
              dtype=dtypes.int32,
              maxval=100),
          functools.partial(
              stateless.stateless_random_uniform, dtype=dtypes.float32),
      ]))
  @test_util.run_v2_only
  def testAlg(self, op, alg_group):
    """Tests all values of `alg`.

    Each `alg_group` lists several spellings of the same algorithm (enum
    member, its `.value`, its string name, and for AUTO_SELECT also None).
    Every spelling must produce the same output shape and identical values.
    """
    if config.list_logical_devices('TPU') or config.list_logical_devices('GPU'):
      self.skipTest('Only _cpu tests linked in support for jit_compile on CPU.')
    seed = [1, 2]
    shape = [2, 3]
    outputs = []
    for alg in alg_group:
      with ops.device('CPU'):
        output = def_function.function(jit_compile=True)(op)(
            shape=shape, seed=seed, alg=alg)
      self.assertEqual(output.shape, shape)
      outputs.append(output)
    x = outputs[0]
    for y in outputs[1:]:
      self.assertAllEqual(x, y)

  def testLargeNormal(self):
    """Tests an OOM bug of StatelessRandomNormalV2 on TPU."""
    with self.session() as sess, self.test_scope():
      seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
      # First path: key, counter and alg from the combined helper op.
      key, counter, alg = (gen_stateless_random_ops_v2.
                           stateless_random_get_key_counter_alg(seed_t))
      x = gen_stateless_random_ops_v2.stateless_random_normal_v2(
          shape=[1024, 32000], key=key, counter=counter, dtype=dtypes.float32,
          alg=alg)
      y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
      self.assertAllEqual([1024, 32000], y.shape)
      # Second path: key/counter and alg from the separate helper ops.
      key, counter = (gen_stateless_random_ops_v2.
                      stateless_random_get_key_counter(seed_t))
      alg = gen_stateless_random_ops_v2.stateless_random_get_alg()
      x = gen_stateless_random_ops_v2.stateless_random_normal_v2(
          shape=[1024, 32000], key=key, counter=counter, dtype=dtypes.float32,
          alg=alg)
      y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
      self.assertAllEqual([1024, 32000], y.shape)

  def testDeterminism(self):
    """Checks that outputs match exactly when, and only when, seeds match."""
    # Stateless values should be equal iff the seeds are equal (roughly)
    with self.session(), self.test_scope():
      seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
      # Every seed appears three times, so both equal-seed and
      # different-seed pairs get compared below.
      seeds = [(x, y) for x in range(-2, 3) for y in range(-2, 3)] * 3  # pylint: disable=g-complex-comprehension
      for stateless_op in [
          stateless.stateless_random_uniform, stateless.stateless_random_normal
      ]:
        for shape in (), (3,), (2, 5):
          for dtype in self._random_types():
            # Skip bfloat16. The result of bfloat16 is truncated from 32-bit
            # result. With different seeds, the 32-bit results are different,
            # but the truncated 16-bit results might be the same.
            if dtype == dtypes.bfloat16:
              continue
            pure = stateless_op(shape, seed=seed_t, dtype=dtype)
            values = [(seed, pure.eval(feed_dict={
                seed_t: seed
            })) for seed in seeds]
            for s0, v0 in values:
              for s1, v1 in values:
                self.assertEqual(s0 == s1, np.all(v0 == v1))

  def testRandomUniformIsInRange(self):
    """Checks uniform samples lie in [0, maxval) for every tested dtype."""
    with self.session() as sess, self.test_scope():
      for dtype in self._random_types(include_int=True):
        maxval = 1
        if dtype.is_integer:
          maxval = 100
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        x = stateless.stateless_random_uniform(
            shape=[1000], seed=seed_t, maxval=maxval, dtype=dtype)
        y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
        self.assertTrue(np.all(y >= 0))
        self.assertTrue(np.all(y < maxval))

  def testDistributionOfStatelessRandomUniform(self):
    """Use Pearson's Chi-squared test to test for uniformity."""
    with self.session() as sess, self.test_scope():
      for dtype in self._random_types(include_int=True):
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        n = 1000
        maxval = 1
        if dtype.is_integer:
          maxval = 100
        x = stateless.stateless_random_uniform(
            shape=[n], seed=seed_t, maxval=maxval, dtype=dtype)
        y = sess.run(x, {seed_t: [565656, 121212]})
        # Convert y to float and normalize its value to range [0, 1) when
        # maxval != 1.
        y = y.astype(float) / maxval
        # Tests that the values are distributed amongst 10 bins with equal
        # probability. 16.92 is the Chi^2 value for 9 degrees of freedom with
        # p=0.05. This test is probabilistic and would be flaky if the random
        # seed were not fixed.
        self.assertLess(random_test_util.chi_squared(y, 10), 16.92)

  def testRandomNormalIsFinite(self):
    """Checks that normal samples contain no NaN or infinite values."""
    with self.session() as sess, self.test_scope():
      for dtype in self._random_types():
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        x = stateless.stateless_random_normal(
            shape=[10000], seed=seed_t, dtype=dtype)
        y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
        self.assertTrue(np.all(np.isfinite(y)))

  def testDistributionOfStatelessRandomNormal(self):
    """Use Anderson-Darling test to test distribution appears normal."""
    with self.session() as sess, self.test_scope():
      for dtype in self._random_types():
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        n = 1000
        x = stateless.stateless_random_normal(
            shape=[n], seed=seed_t, dtype=dtype)
        y = sess.run(x, {seed_t: [25252, 314159]})
        # The constant 2.492 is the 5% critical value for the Anderson-Darling
        # test where the mean and variance are known. This test is probabilistic
        # so to avoid flakiness the seed is fixed.
        self.assertLess(
            random_test_util.anderson_darling(y.astype(float)), 2.492)

  def testTruncatedNormal(self):
    """Checks truncated-normal sample statistics for every tested dtype.

    bfloat16 gets a looser variance tolerance (6e-3) than other dtypes (1e-3).
    """
    for dtype in self._random_types():
      with self.session() as sess, self.test_scope():
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        n = 10000000
        x = stateless.stateless_truncated_normal(
            shape=[n], seed=seed_t, dtype=dtype)
        y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
        random_test_util.test_truncated_normal(
            self.assertEqual, self.assertAllClose, n, y,
            variance_rtol=6e-3 if dtype == dtypes.bfloat16 else 1e-3)

  def _testParameterizedTruncatedNormal(self,
                                        means,
                                        stddevs,
                                        minvals,
                                        maxvals,
                                        variance_rtol=None):
    """Checks stateless_parameterized_truncated_normal sample statistics.

    Args:
      means: Forwarded unchanged to both
        `stateless.stateless_parameterized_truncated_normal` and
        `random_test_util.test_truncated_normal`.
      stddevs: Forwarded like `means`.
      minvals: Forwarded like `means`.
      maxvals: Forwarded like `means`.
      variance_rtol: Relative tolerance for the variance check. If None, each
        dtype gets its own default: 6e-3 for bfloat16, 1e-3 otherwise.
    """
    for dtype in self._random_types():
      # Resolve the tolerance per iteration into a local. Writing the default
      # back into `variance_rtol` would make the first dtype's default stick
      # for every later dtype. Which dtype comes first depends on set
      # iteration order.
      if variance_rtol is None:
        rtol = 6e-3 if dtype == dtypes.bfloat16 else 1e-3
      else:
        rtol = variance_rtol
      with self.session() as sess, self.test_scope():
        seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
        n = int(10e7)  # 10e7 == 1e8 samples.
        # NOTE(review): `dtype` is not passed to the op. The sample dtype is
        # presumably inferred from the (Python float) parameters, so each
        # iteration may produce the same float32 samples. Confirm.
        x = stateless.stateless_parameterized_truncated_normal(
            shape=[n],
            seed=seed_t,
            means=means,
            stddevs=stddevs,
            minvals=minvals,
            maxvals=maxvals)
        y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
        random_test_util.test_truncated_normal(
            self.assertEqual,
            self.assertAllClose,
            n,
            y,
            means=means,
            stddevs=stddevs,
            minvals=minvals,
            maxvals=maxvals,
            mean_atol=1e-3,
            median_atol=1e-3,
            variance_rtol=rtol)

  def testParameterizedTruncatedNormalDefault(self):
    self._testParameterizedTruncatedNormal(0., 1., -2., 2.)

  def testParameterizedTruncatedNormalShifted(self):
    self._testParameterizedTruncatedNormal(-1., 1., -2., 2.)

  def testParameterizedTruncatedNormalRightTail(self):
    self._testParameterizedTruncatedNormal(0., 1., 4., 20., variance_rtol=2e-2)

  def testParameterizedTruncatedNormalLeftTail(self):
    self._testParameterizedTruncatedNormal(
        0., 1., -20., -4., variance_rtol=5e-2)

  def testParameterizedTruncatedNormalLeftTailTwoSidedBounds(self):
    self._testParameterizedTruncatedNormal(
        0., 1., -6., -3., variance_rtol=5e-2)

  def testParameterizedTruncatedNormalSmallStddev(self):
    self._testParameterizedTruncatedNormal(0., 0.1, 0.05, 0.10)

  def testParameterizedTruncatedNormalBroadcast(self):
    """Checks that parameter tensors broadcast against the requested shape."""
    with self.session() as sess, self.test_scope():
      seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
      means = array_ops.zeros([2], dtype=dtypes.float32)
      stddevs = array_ops.ones([3, 1], dtype=dtypes.float32)
      minvals = -array_ops.ones([5, 1, 1], dtype=dtypes.float32)
      maxvals = array_ops.ones([7, 1, 1, 1], dtype=dtypes.float32)
      shape = [11, 7, 5, 3, 2]
      x = stateless.stateless_parameterized_truncated_normal(
          shape=shape,
          seed=seed_t,
          means=means,
          stddevs=stddevs,
          minvals=minvals,
          maxvals=maxvals)
      y = sess.run(x, {seed_t: [0x12345678, 0xabcdef1]})
      self.assertEqual((11, 7, 5, 3, 2), y.shape)


class StatelessRandomOpsBenchmark(test.Benchmark):
  """Microbenchmarks for the stateless random ops."""

  def _benchmarkUniform(self, name, dtype, use_xla_jit):
    """Benchmarks a (10, 1000, 1000) stateless uniform draw on CPU.

    Args:
      name: Base benchmark name; the shape is appended to it.
      dtype: dtype of the generated samples.
      use_xla_jit: Forwarded to `xla_test.Benchmark` to toggle XLA JIT.
    """

    def builder_fn():
      shape = (10, 1000, 1000)
      seed_var = variables.Variable((312, 456),
                                    dtype=dtypes.int32,
                                    name='input')
      random_t = stateless.stateless_random_uniform(
          shape, seed=seed_var, dtype=dtype)
      return '%s.shape%s' % (name, shape), [random_t]

    xla_test.Benchmark(self, builder_fn, use_xla_jit=use_xla_jit, device='cpu')

  def benchmarkUniformF32(self):
    self._benchmarkUniform(
        'uniform_f32', dtype=dtypes.float32, use_xla_jit=False)

  def benchmarkUniformF64(self):
    self._benchmarkUniform(
        'uniform_f64', dtype=dtypes.float64, use_xla_jit=False)

  def benchmarkUniformF32XLA(self):
    self._benchmarkUniform(
        'uniform_f32', dtype=dtypes.float32, use_xla_jit=True)

  def benchmarkUniformF64XLA(self):
    self._benchmarkUniform(
        'uniform_f64', dtype=dtypes.float64, use_xla_jit=True)


if __name__ == '__main__':
  config.set_soft_device_placement(False)
  test.main()