xref: /aosp_15_r20/external/pytorch/torch/_inductor/runtime/halide_helpers.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2try:
3    import halide as hl  # type: ignore[import-untyped, import-not-found]
4except ImportError:
5    hl = None
6
# Round count used by the helpers below when the caller does not pass one.
PHILOX_N_ROUNDS_DEFAULT = 10

# The Philox key/round constants are Halide u32 values, so they can only be
# built when the `halide` import succeeded; otherwise they are left as None.
if hl is not None:
    PHILOX_KEY_A_U32, PHILOX_KEY_B_U32 = (
        hl.u32(0x9E3779B9),
        hl.u32(0xBB67AE85),
    )
    PHILOX_ROUND_A_U32, PHILOX_ROUND_B_U32 = (
        hl.u32(0xD2511F53),
        hl.u32(0xCD9E8D57),
    )
else:
    PHILOX_KEY_A_U32 = PHILOX_KEY_B_U32 = None
    PHILOX_ROUND_A_U32 = PHILOX_ROUND_B_U32 = None
19
20
def _pair_uniform_to_normal(u1, u2):
    """Apply the Box-Muller transform to a pair of uniform samples.

    ``u1`` is clamped from below to 1e-7 before its logarithm is taken.

    Returns:
        The pair ``(radius * cos(angle), radius * sin(angle))`` with
        ``radius = sqrt(-2 * log(u1))`` and ``angle = 2 * pi * u2``.
    """
    clamped = hl.max(hl.f32(1.0e-7), u1)
    radius = hl.sqrt(hl.f32(-2.0) * hl.log(clamped))
    angle = hl.f32(6.283185307179586) * u2
    return radius * hl.cos(angle), radius * hl.sin(angle)
27
28
def _uint_to_uniform_float(x):
    """
    Numerically stable conversion of a random 32-bit integer expression into a
    random float uniformly sampled in [0, 1).

    The input must have Halide type ``UInt(32)`` or ``Int(32)``. It is
    reinterpreted as ``Int(32)``, negative values ``v`` are mapped to
    ``-v - 1`` so the value is non-negative, and the result is multiplied by
    an ``f64`` scale factor.
    """

    # TODO:
    # conditions can be simplified
    # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1)
    # https://github.com/triton-lang/triton/blob/e4a0d93ff1a367c7d4eeebbcd7079ed267e6b06f/python/triton/language/random.py#L116-L132.
    assert x.type() in (hl.UInt(32), hl.Int(32))
    signed = hl.cast(hl.Int(32), x)
    non_negative = hl.select(signed < 0, -signed - 1, signed)
    return non_negative * hl.f64(4.6566127342e-10)
43
44
def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds):
    """Apply ``n_rounds`` Philox rounds to the counters ``c0..c3``.

    ``k0`` and ``k1`` are the two key words; they are bumped by the
    ``PHILOX_KEY_*`` constants after every round. ``n_rounds`` is handed to
    ``range()``, so it must be a plain Python integer.

    Returns:
        The four counter expressions ``(c0, c1, c2, c3)`` after the last round.
    """

    def _mul_high32(lhs, rhs):
        # Widen both operands to 64 bits, multiply, keep the upper 32 bits.
        product = hl.cast(hl.UInt(64), lhs) * hl.cast(hl.UInt(64), rhs)
        return hl.cast(hl.UInt(32), (product >> 32) & hl.u64(0xFFFFFFFF))

    for _ in range(n_rounds):
        # The whole right-hand side is evaluated before any name is rebound,
        # so every term below reads the previous round's counters.
        c0, c1, c2, c3 = (
            _mul_high32(PHILOX_ROUND_B_U32, c2) ^ c1 ^ k0,
            PHILOX_ROUND_B_U32 * c2,
            _mul_high32(PHILOX_ROUND_A_U32, c0) ^ c3 ^ k1,
            PHILOX_ROUND_A_U32 * c0,
        )
        # Raise the key for the next round.
        k0, k1 = k0 + PHILOX_KEY_A_U32, k1 + PHILOX_KEY_B_U32

    return c0, c1, c2, c3
63
64
def halide_philox(seed, c0, c1, c2, c3, n_rounds):
    """Run Philox on counters ``c0..c3`` using ``seed`` as the key.

    ``seed`` is cast to ``UInt(64)`` and split into two ``UInt(32)`` words;
    the low word is passed to ``philox_impl`` as ``k0`` and the high word as
    ``k1``. ``c0`` is asserted to be a 32-bit type.
    """
    seed64 = hl.cast(hl.UInt(64), seed)

    assert c0.type().bits() == 32

    low32_mask = hl.u64(0xFFFFFFFF)
    key_hi = hl.cast(hl.UInt(32), (seed64 >> 32) & low32_mask)
    key_lo = hl.cast(hl.UInt(32), seed64 & low32_mask)

    return philox_impl(c0, c1, c2, c3, key_lo, key_hi, n_rounds)
74
75
def randint4x(seed, offset, n_rounds):
    """Return the four Philox output words for ``seed`` and ``offset``.

    ``offset`` is cast to ``UInt(32)`` and used as the first counter word;
    the remaining three counter words are zero.
    """
    counters = (hl.cast(hl.UInt(32), offset),) + (hl.u32(0),) * 3
    return halide_philox(seed, *counters, n_rounds)
80
81
def rand4x(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
    """Return four uniform floats, one per output word of ``randint4x``.

    Each word is converted with ``_uint_to_uniform_float``; the result is a
    4-tuple in the same order as the words.
    """
    return tuple(
        _uint_to_uniform_float(word) for word in randint4x(seed, offset, n_rounds)
    )
89
90
def randint(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
    """Return the first of the four words produced by ``randint4x``."""
    return randint4x(seed, offset, n_rounds)[0]
94
95
def rand(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
    """Return one uniform float: ``randint`` passed through ``_uint_to_uniform_float``."""
    return _uint_to_uniform_float(randint(seed, offset, n_rounds))
99
100
def randn(seed, offset):
    """Return one normally distributed sample for ``seed`` and ``offset``.

    The first two words of ``randint4x`` (run with the default round count)
    are converted to uniform floats and fed to the Box-Muller helper; only the
    first value of the resulting pair is returned.
    """
    words = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT)
    uniform_a = _uint_to_uniform_float(words[0])
    uniform_b = _uint_to_uniform_float(words[1])
    return _pair_uniform_to_normal(uniform_a, uniform_b)[0]
107
108
def randint64(seed, offset, low, high):
    """Return a random ``Int(64)`` expression offset from ``low``.

    Two 32-bit Philox words are combined into one 64-bit unsigned value
    (first word in the low half, second word in the high half), reduced
    modulo ``high - low``, cast to ``Int(64)`` and shifted by ``low``. For
    ``high > low`` this puts the result in ``[low, high)``.
    """
    lo_word, hi_word, _, _ = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT)
    raw = hl.cast(hl.UInt(64), lo_word) | (hl.cast(hl.UInt(64), hi_word) << 32)
    span = hl.cast(hl.UInt(64), high - low)
    return hl.cast(hl.Int(64), raw % span) + low
119