# mypy: allow-untyped-defs
"""Philox-based random number helpers for Halide-generated kernels.

Every function here builds Halide expressions (via ``hl.cast``, ``hl.select``,
...) rather than computing values eagerly.  The algorithms are modeled on
Triton's ``triton.language.random`` module.
"""

try:
    import halide as hl  # type: ignore[import-untyped, import-not-found]
except ImportError:
    # Halide is optional: the module stays importable without it, but the
    # helpers below dereference ``hl`` and therefore require it when called.
    hl = None

PHILOX_N_ROUNDS_DEFAULT = 10  # Default number of rounds for philox

# Philox-4x32 key-schedule ("KEY") and multiplier ("ROUND") constants, as
# 32-bit unsigned Halide literals.  They can only be built when Halide is
# available; otherwise they are left as None.
if hl is not None:
    PHILOX_KEY_A_U32 = hl.u32(0x9E3779B9)
    PHILOX_KEY_B_U32 = hl.u32(0xBB67AE85)
    PHILOX_ROUND_A_U32 = hl.u32(0xD2511F53)
    PHILOX_ROUND_B_U32 = hl.u32(0xCD9E8D57)
else:
    PHILOX_KEY_A_U32 = None
    PHILOX_KEY_B_U32 = None
    PHILOX_ROUND_A_U32 = None
    PHILOX_ROUND_B_U32 = None


def _pair_uniform_to_normal(u1, u2):
    """Box-Muller transform.

    Maps a pair of uniform samples to a pair of normal samples: with
    ``r = sqrt(-2 * ln(u1))`` and ``th = 2 * pi * u2`` it returns
    ``(r * cos(th), r * sin(th))``.
    """
    # Clamp u1 away from zero so that log(u1) stays finite.
    u1 = hl.max(hl.f32(1.0e-7), u1)
    th = hl.f32(6.283185307179586) * u2  # 2 * pi * u2
    r = hl.sqrt(hl.f32(-2.0) * hl.log(u1))
    return r * hl.cos(th), r * hl.sin(th)


def _uint_to_uniform_float(x):
    """
    Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).

    ``x`` must be a 32-bit (signed or unsigned) Halide expression.  The result
    is ``int32 * f64``, which Halide's type coercion makes an f64 expression.
    """
    # Ported from Triton:
    # https://github.com/triton-lang/triton/blob/e4a0d93ff1a367c7d4eeebbcd7079ed267e6b06f/python/triton/language/random.py#L116-L132.
    # TODO: the dtype conditions can be simplified.
    assert x.type() == hl.UInt(32) or x.type() == hl.Int(32)
    # Reinterpret as int32 so the sign bit can be folded away below.
    x = hl.cast(hl.Int(32), x)
    # Just below 2**-31, chosen (as in Triton) so that INT32_MAX * scale < 1.0.
    scale = hl.f64(4.6566127342e-10)
    # Fold negative values onto [0, 2**31 - 1].  For x < 0, -(x + 1) equals
    # -x - 1 but never overflows: -x alone overflows for x == INT32_MIN, and
    # Halide assumes signed 32-bit arithmetic does not overflow.
    x = hl.select(x < 0, -(x + 1), x)
    return x * scale


def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds):
    """Apply ``n_rounds`` Philox-4x32 rounds to counters ``c0..c3`` with key ``(k0, k1)``.

    The counters and key words are expected to be ``UInt(32)`` Halide
    expressions (the callers in this module pass such values).  ``n_rounds``
    is a Python int, so the rounds are unrolled while the expression is built.
    Returns the four scrambled 32-bit words ``(c0, c1, c2, c3)``.
    """

    def umulhi(a, b):
        # High 32 bits of the full 64-bit product of two 32-bit values.
        a = hl.cast(hl.UInt(64), a)
        b = hl.cast(hl.UInt(64), b)
        return hl.cast(hl.UInt(32), ((a * b) >> 32) & hl.u64(0xFFFFFFFF))

    for _ in range(n_rounds):
        # Snapshot the inputs: c0 and c2 are overwritten before c1/c3 use them.
        _c0, _c2 = c0, c2

        c0 = umulhi(PHILOX_ROUND_B_U32, _c2) ^ c1 ^ k0
        c2 = umulhi(PHILOX_ROUND_A_U32, _c0) ^ c3 ^ k1
        # Low 32 bits of the products (UInt(32) multiplication wraps).
        c1 = PHILOX_ROUND_B_U32 * _c2
        c3 = PHILOX_ROUND_A_U32 * _c0
        # raise key
        k0 = k0 + PHILOX_KEY_A_U32
        k1 = k1 + PHILOX_KEY_B_U32

    return c0, c1, c2, c3


def halide_philox(seed, c0, c1, c2, c3, n_rounds):
    """Run Philox-4x32 on counters ``c0..c3``, keyed by a 64-bit ``seed``.

    ``seed`` is cast to ``UInt(64)``; its low and high 32-bit halves become
    the two key words.  Only 32-bit counters are supported.
    """
    seed = hl.cast(hl.UInt(64), seed)

    assert c0.type().bits() == 32, "halide_philox only supports 32-bit counters"

    seed_hi = hl.cast(hl.UInt(32), (seed >> 32) & hl.u64(0xFFFFFFFF))
    seed_lo = hl.cast(hl.UInt(32), seed & hl.u64(0xFFFFFFFF))

    return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)


def randint4x(seed, offset, n_rounds):
    """Return four random ``UInt(32)`` expressions for ``(seed, offset)``.

    ``offset`` (cast to ``UInt(32)``) is the first counter word; the remaining
    three counter words are zero.
    """
    offset = hl.cast(hl.UInt(32), offset)
    _0 = hl.u32(0)
    return halide_philox(seed, offset, _0, _0, _0, n_rounds)


def rand4x(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
    """Return four uniform floats in [0, 1) for ``(seed, offset)``."""
    i1, i2, i3, i4 = randint4x(seed, offset, n_rounds)
    u1 = _uint_to_uniform_float(i1)
    u2 = _uint_to_uniform_float(i2)
    u3 = _uint_to_uniform_float(i3)
    u4 = _uint_to_uniform_float(i4)
    return u1, u2, u3, u4


def randint(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
    """Return one random ``UInt(32)`` expression (the first Philox output word)."""
    ret, _, _, _ = randint4x(seed, offset, n_rounds)
    return ret


def rand(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
    """Return one uniform float in [0, 1) for ``(seed, offset)``."""
    source = randint(seed, offset, n_rounds)
    return _uint_to_uniform_float(source)


def randn(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
    """Return one normally distributed sample for ``(seed, offset)``.

    The first two Philox output words are converted to uniforms and combined
    with the Box-Muller transform.  Only the first of the two normal outputs
    is returned.
    """
    i1, i2, _, _ = randint4x(seed, offset, n_rounds)
    u1 = _uint_to_uniform_float(i1)
    u2 = _uint_to_uniform_float(i2)
    n1, _ = _pair_uniform_to_normal(u1, u2)
    return n1


def randint64(seed, offset, low, high, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
    """Return a random ``Int(64)`` expression in ``[low, high)``.

    Two 32-bit Philox words are concatenated into a 64-bit value, which is
    reduced modulo ``high - low`` and shifted by ``low``.  A plain modulo is
    used, so there is a slight bias when ``high - low`` does not divide 2**64.
    Assumes ``high > low`` (TODO confirm callers guarantee this).
    """
    r0, r1, r2, r3 = randint4x(seed, offset, n_rounds)
    r0 = hl.cast(hl.UInt(64), r0)
    r1 = hl.cast(hl.UInt(64), r1)

    # r1 supplies the high 32 bits, r0 the low 32 bits.
    result = r0 | (r1 << 32)
    size = high - low
    result = result % hl.cast(hl.UInt(64), size)
    result = hl.cast(hl.Int(64), result) + low
    return result