xref: /aosp_15_r20/external/pytorch/torch/_inductor/runtime/triton_helpers.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import triton
import triton.language as tl


# In the latest triton, math functions were shuffled around into different modules:
# https://github.com/openai/triton/pull/3172
try:
    from triton.language.extra import libdevice

    libdevice = tl.extra.libdevice  # noqa: F811
    math = tl.math
except ImportError:
    if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"):
        libdevice = tl.extra.cuda.libdevice
        math = tl.math
    elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"):
        libdevice = tl.extra.intel.libdevice
        math = tl.math
    else:
        libdevice = tl.math
        math = tl


try:
    from triton.language.standard import _log2
except ImportError:

    def _log2(x):
        raise NotImplementedError


@triton.jit
def promote_to_tensor(x):
    # Addition promotes to tensor for us
    return x + tl.zeros((1,), tl.int1)

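# Example (C vs. floor division): with a = -7 and b = 2, C-style truncating
# division gives quot = -3 and remainder = -1; the remainder is nonzero and the
# operands have opposite signs, so the result is adjusted to quot - 1 = -4,
# which equals floor(-7 / 2).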
@triton.jit
def div_floor_integer(a, b):
    # NOTE: a // b is C division, but we want floor division
    # Based on c10::div_floor_integer
    quot = a // b
    remainder = a % b
    fixed = tl.where(remainder != 0, quot - 1, quot)
    return tl.where((a < 0) != (b < 0), fixed, quot)

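# Example: with a = -7 and b = 2 the C-style remainder is -1; the operands have
# opposite signs and the remainder is nonzero, so b is added back to give 1,
# the floor-division (Python-style) remainder.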
@triton.jit
def remainder_integer(a, b):
    # NOTE: a % b matches C division, not floor division
    remainder = a % b
    return tl.where(remainder != 0 and ((a < 0) != (b < 0)), remainder + b, remainder)


@triton.jit
def is_floating(x):
    return promote_to_tensor(x).dtype.is_floating()


@triton.jit
def _prod_accumulate(a, b):
    return a * b


@triton.jit
def prod(input, axis):
    return tl.reduce(input, axis, _prod_accumulate)

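# NaN-propagating elementwise min/max, used as combine functions for tl.reduce
# (see min2/max2 below): if a is NaN, the explicit `a != a` check selects a; if
# b is NaN, the ordinary comparison is False and b is selected. Either way a
# NaN in the input propagates to the result.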
@triton.jit
def minimum(a, b):
    mask = a < b
    if is_floating(a):
        mask |= a != a
    return tl.where(mask, a, b)


@triton.jit
def maximum(a, b):
    mask = a > b
    if is_floating(a):
        mask |= a != a
    return tl.where(mask, a, b)


@triton.jit
def min2(a, dim):
    return tl.reduce(a, dim, minimum)


@triton.jit
def max2(a, dim):
    return tl.reduce(a, dim, maximum)

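# Value/index pair combine functions for reductions that also track the
# argmin/argmax: NaN wins over non-NaN (so NaN propagates), NaN compares equal
# to NaN, and ties are broken in favour of the lower index.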
@triton.jit
def minimum_with_index(a_value, a_index, b_value, b_index):
    mask = a_value < b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        mask |= a_isnan and not b_isnan
        # Consider NaNs as equal
        equal |= a_isnan and b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def maximum_with_index(a_value, a_index, b_value, b_index):
    mask = a_value > b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        mask |= a_isnan and not b_isnan
        # Consider NaNs as equal
        equal |= a_isnan and b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def min_with_index(value, index, dim):
    return tl.reduce((value, index), dim, minimum_with_index)


@triton.jit
def max_with_index(value, index, dim):
    return tl.reduce((value, index), dim, maximum_with_index)

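# One streaming update step of Welford's algorithm: `mean` is the running mean,
# `m2` the running sum of squared deviations from the mean, and `weight` the
# element count. On the first iteration the accumulators start as
# (value, 0, 1).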
@triton.jit
def welford_reduce(value, mean, m2, weight, first_iteration):
    if first_iteration:
        new_weight = tl.full(weight.shape, 1, weight.dtype)
        new_mean = value
        new_m2 = tl.zeros_like(m2)
    else:
        delta = value - mean
        new_weight = weight + 1
        new_mean = mean + delta / new_weight
        new_m2 = m2 + delta * (value - new_mean)
    return new_mean, new_m2, new_weight

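# Merge two Welford accumulators (the parallel variance combination of
# Chan et al.): with delta = mean_2 - mean_1 and W = weight_1 + weight_2,
#   mean = mean_1 + delta * weight_2 / W
#   m2   = m2_1 + m2_2 + delta**2 * weight_1 * weight_2 / W
# The weight_2 / W factor is guarded against W == 0.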
@triton.jit
def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)
    return (
        mean_1 + delta * w2_over_w,
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
        new_weight,
    )


@triton.jit
def welford(mean, m2, weight, dim):
    return tl.reduce((mean, m2, weight), dim, welford_combine)


@triton.jit
def device_assert_then(cond, msg, r):
    tl.device_assert(cond, msg)
    return r

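# Draw 64-bit random integers in [low, high): two 32-bit outputs of
# tl.randint4x are concatenated into a uint64, reduced modulo (high - low), and
# offset by low.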
@triton.jit
def randint64(seed, offset, low, high):
    r0, r1, r2, r3 = tl.randint4x(seed, offset)
    r0 = r0.to(tl.uint64)
    r1 = r1.to(tl.uint64)
    result = r0 | (r1 << 32)
    size = high - low
    result = result % size.to(tl.uint64)
    result = result.to(tl.int64) + low
    return result


@triton.jit
def _any_combine(a, b):
    return a | b


@triton.jit
def any(a, dim):
    return tl.reduce(a, dim, _any_combine)

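# Branch-free binary search: every lane runs the same
# ceil(log2(OFFSETS_SIZE + 1)) iterations, narrowing [low, high) with tl.where;
# `low` converges to each value's bucket index under the closed-left or
# closed-right convention selected by `right`.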
@triton.jit
def bucketize_binary_search(
    values,  # 1D tensor
    offsets_ptr,
    indexing_dtype,
    right,  # bool: if true, use intervals closed on the left; see [Note: Inductor bucketize op]
    OFFSETS_SIZE: int,
    BLOCK_SHAPE,  # tuple/list of block shape
):
    """
    See [Note: Inductor bucketize op]
    """

    low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)
    high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)

    full_range = OFFSETS_SIZE + 1
    while full_range > 1:
        mid = (high + low) // 2
        mask = mid < OFFSETS_SIZE
        bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask, other=0.0)
        if right:
            is_above = values >= bucket_upper_bound
        else:
            is_above = values > bucket_upper_bound

        low = tl.where(is_above & mask, mid + 1, low)
        high = tl.where(is_above, high, mid)

        full_range = (full_range + 1) // 2

    return low

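# A (value, flag) pair is packed into one unsigned integer of twice the value's
# width: the value's bits (bitcast to DTYPE_VALUE_AS_UINT) occupy the upper
# half and the flag the lower half. unpack_value shifts right to recover the
# value; unpack_flag truncates to recover the flag.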
@triton.jit
def pack_value_flag(
    value,
    flag,
    DTYPE_VALUE_AS_UINT: tl.constexpr,
    DTYPE_PACK: tl.constexpr,
):
    # Workaround for triton bug, tensor.to doesn't unwrap constexpr values
    DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)
    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
    uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)
    return flag.to(DTYPE_PACK) | (uv << bitwidth)


@triton.jit
def unpack_value(
    pack,
    DTYPE_VALUE,
    DTYPE_VALUE_AS_UINT,
):
    # Workaround for triton bug, tensor.to doesn't unwrap constexpr values
    DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)
    DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)
    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
    value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)
    return value_uint.to(DTYPE_VALUE, bitcast=True)


@triton.jit
def unpack_flag(pack, DTYPE_FLAG):
    return pack.to(DTYPE_FLAG)

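# Per-block scratch protocol for the decoupled look-back scan: flag 0 means
# nothing published yet, flag 1 means the block's local aggregate is available,
# flag 2 means its inclusive prefix is available. The look-back loop walks
# backwards, spinning until a predecessor publishes something, and stops as
# soon as it reads a flag of 2 (a complete inclusive prefix).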
@triton.jit
def exclusive_scan_decoupled_lookback(
    scratch_base,
    block_value,
    index,
    combine_fn,
    DTYPE_VALUE_AS_UINT: tl.constexpr,
    DTYPE_PACK: tl.constexpr,
):
    """Compute exclusive scan of a scalar value between blocks

    Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back

    scratch_base: Pointer to scratch space in global memory
    block_value: Scalar value for this block
    index: Scalar index of this block relative to the current scan
    combine_fn: Function ``(value, value) -> value`` which is scanned over
    DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``
    DTYPE_PACK: Unsigned type twice the width of block_value

    NOTE: This function is limited to values which are 32 bits or less because
    we need to pack (value, flag) into a single unsigned int.
    """
    # Publish block sum so subsequent blocks don't get stuck waiting for us
    DTYPE_VALUE = block_value.dtype
    pack = pack_value_flag(
        block_value,
        tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),
        DTYPE_VALUE_AS_UINT,
        DTYPE_PACK,
    )
    if index > 0:
        tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")

    # Calculate exclusive prefix scan
    exclusive_prefix = tl.zeros([], DTYPE_VALUE)
    prefix_valid = False
    test_target = index - 1
    while test_target >= 0:
        # Atomic load, emulated with an atomic add of zero
        flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)
        while flag == 0:
            pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed")
            flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)

        value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)
        if prefix_valid:
            exclusive_prefix = combine_fn(value, exclusive_prefix)
        else:
            exclusive_prefix = value
            prefix_valid = True

        if flag == 2:
            test_target = -1
        else:
            test_target = test_target - 1

    # Make inclusive block sum visible to other blocks
    if prefix_valid:
        inclusive_prefix = combine_fn(exclusive_prefix, block_value)
    else:
        inclusive_prefix = block_value
    pack = pack_value_flag(
        inclusive_prefix,
        tl.full([], 2, DTYPE_VALUE_AS_UINT),
        DTYPE_VALUE_AS_UINT,
        DTYPE_PACK,
    )
    tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")
    return exclusive_prefix

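# 64-bit variant: a (value, flag) pair no longer fits in a single packed word,
# so each block gets three uint64 scratch slots: [flag, local aggregate,
# inclusive prefix]. The flag value (1 or 2) doubles as the offset of the slot
# to read, and release/acquire ordering on the flag replaces the packed
# atomic exchange.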
@triton.jit
def exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn):
    """Compute exclusive scan of a scalar value between blocks

    Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back

    scratch_base: Pointer to scratch space in global memory
    block_value: Scalar value for this block, must be 64-bits wide
    index: Scalar index of this block relative to the current scan
    combine_fn: Function ``(value, value) -> value`` which is scanned over
    init: Scalar value equal to the identity of combine_fn
    """
    # Publish block sum so subsequent blocks don't get stuck waiting for us
    if index > 0:
        block_value_u64 = block_value.to(tl.uint64, bitcast=True)
        tl.store(scratch_base + 3 * index + 1, block_value_u64)
        tl.debug_barrier()
        flag_one = tl.full([], 1, tl.uint64)
        tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release")

    # Calculate exclusive prefix scan
    exclusive_prefix = tl.zeros([], block_value.dtype)
    prefix_valid = False
    test_target = index - 1
    while test_target >= 0:
        flag = tl.full([], 0, tl.uint64)
        while flag == 0:
            flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire")

        value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))
        value = value_u64.to(block_value.dtype, bitcast=True)
        if prefix_valid:
            exclusive_prefix = combine_fn(value, exclusive_prefix)
        else:
            exclusive_prefix = value
            prefix_valid = True

        if flag == 2:
            test_target = -1
        else:
            test_target = test_target - 1

    # Make inclusive block sum visible to other blocks
    if prefix_valid:
        inclusive_prefix = combine_fn(exclusive_prefix, block_value)
    else:
        inclusive_prefix = block_value
    inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)
    tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)
    tl.debug_barrier()
    flag_two = tl.full([], 2, tl.uint64)
    tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release")

    return exclusive_prefix

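# Decompose x into (mantissa, exponent) such that x == mantissa * 2**exponent
# with |mantissa| in [0.5, 1) for nonzero x, following the math.frexp
# convention; x == 0 yields (0, 0).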
@triton.jit
def frexp(x):
    # TODO(isuruf): use inline_asm_elementwise here
    y = libdevice.ilogb(x) + 1
    exponent = tl.where(x == 0, 0, y)
    mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))
    return mantissa, exponent

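# One compare-and-swap stage of the bitonic sorter: values are bit-cast to
# integers so a conditional swap can be expressed as XOR with (left ^ right),
# and the index tensor is permuted identically. When rnumel is given, lanes
# whose index is at or beyond rnumel are treated as padding in the comparison.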
@triton.jit
def _compare_and_swap_with_index(
    x,
    idxs,
    rnumel,
    flip,
    i: tl.constexpr,
    n_dims: tl.constexpr,
    stable: tl.constexpr,
    descending: tl.constexpr,
):
    n_outer: tl.constexpr = x.numel >> n_dims
    shape: tl.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)

    y = tl.reshape(x, shape)
    iy = y.to(idtype, bitcast=True)
    # slice left/right with 'stride' 2**(n_dims - i - 1)
    right_mask = tl.arange(0, 2)[None, :, None].to(idtype)
    left_mask = (1 - right_mask).to(idtype)
    ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1)[:, None, :], shape)
    iright = tl.broadcast_to(tl.sum(iy * right_mask, 1)[:, None, :], shape)
    ileft = tl.reshape(ileft, x.shape)
    iright = tl.reshape(iright, x.shape)
    left = ileft.to(x.dtype, bitcast=True)
    right = iright.to(x.dtype, bitcast=True)

    # gather the corresponding left/right indices
    y_idx = tl.reshape(idxs, shape)
    left_idx = tl.broadcast_to(
        tl.sum(y_idx * left_mask.to(y_idx.dtype), 1)[:, None, :], shape
    )
    right_idx = tl.broadcast_to(
        tl.sum(y_idx * right_mask.to(y_idx.dtype), 1)[:, None, :], shape
    )
    left_idx = tl.reshape(left_idx, x.shape)
    right_idx = tl.reshape(right_idx, x.shape)

    # validity masks: indices at or beyond rnumel mark padding lanes
    if rnumel is None:
        left_valid_mask = tl.full(x.shape, True, tl.int1)
        right_valid_mask = tl.full(x.shape, True, tl.int1)
    else:
        left_valid_mask = left_idx < rnumel
        right_valid_mask = right_idx < rnumel

    # actual compare-and-swap
    ix = x.to(idtype, bitcast=True)

    if descending:
        cond = left < right
    else:
        cond = left > right

    if stable:
        # When stable sorting, tie break by index
        cond = cond | ((left == right) & (left_idx > right_idx))

    cond = (right_valid_mask > left_valid_mask) | (
        (right_valid_mask == left_valid_mask) & cond
    )
    cond = cond ^ flip
    ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix))
    new_idxs = idxs ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(idxs))

    return ret.to(x.dtype, bitcast=True), new_idxs


@triton.jit
def _bitonic_merge_with_index(
    x,
    idxs,
    rnumel,
    stage: tl.constexpr,
    alternating: tl.constexpr,
    n_dims: tl.constexpr,
    stable: tl.constexpr,
    descending: tl.constexpr,
):
    n_outer: tl.constexpr = x.numel >> n_dims
    tl.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if alternating:
        shape: tl.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = tl.reshape(
            tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        flip = False
    # perform `stage` rounds of `compare-and-swap`
    for i in tl.static_range(stage):
        x, idxs = _compare_and_swap_with_index(
            x, idxs, rnumel, flip, i + (n_dims - stage), n_dims, stable, descending
        )
    return x, idxs

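# Bitonic sort along the last dimension, which must therefore have a
# power-of-two length; idxs is permuted together with x so callers can recover
# argsort-style indices, and rnumel gives the number of valid elements per row
# (lanes at or beyond it are padding). For example (illustrative shapes only):
# for x and idxs of shape [XBLOCK, RBLOCK] with RBLOCK a power of two, each row
# of x is sorted (ascending by default) and idxs is permuted identically.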
@triton.jit
def sort_with_index(
    x,  # value
    idxs,  # index
    rnumel,  # number of elements
    dim: tl.constexpr = None,
    stable: tl.constexpr = tl.constexpr(False),
    descending: tl.constexpr = tl.constexpr(False),
):
    x, idxs = tl.broadcast(x, idxs)
    # handle default dimension or check that it is the most minor dim
    _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim
    tl.static_assert(
        _dim == len(x.shape) - 1, "only minor dimension is currently supported"
    )
    # iteratively run bitonic merge-sort steps
    n_dims: tl.constexpr = _log2(x.shape[_dim])

    for i in tl.static_range(1, n_dims + 1):
        x, idxs = _bitonic_merge_with_index(
            x,
            idxs,
            rnumel,
            i,
            alternating=i < n_dims,
            n_dims=n_dims,
            stable=stable,
            descending=descending,
        )
    return x, idxs

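# Select the single element of x picked out by a one-hot mask along dim: x is
# bit-cast to an unsigned integer so that multiplying by the 0/1 mask and
# summing returns that element's bit pattern unchanged (assumes exactly one
# nonzero mask entry per reduced slice).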
@triton.jit
def select_one(x, mask, dim, keep_dims=False):
    idtype = tl.core.get_int_dtype(x.dtype.primitive_bitwidth, signed=False)
    ix = x.to(idtype, bitcast=True)
    iy = tl.sum(ix * mask, dim, keep_dims=keep_dims)
    return iy.to(x.dtype, bitcast=True)
543