# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import triton
import triton.language as tl


# In the latest triton, math functions were shuffled around into different modules:
# https://github.com/openai/triton/pull/3172
try:
    from triton.language.extra import libdevice

    libdevice = tl.extra.libdevice  # noqa: F811
    math = tl.math
except ImportError:
    if hasattr(tl.extra, "cuda") and hasattr(tl.extra.cuda, "libdevice"):
        libdevice = tl.extra.cuda.libdevice
        math = tl.math
    elif hasattr(tl.extra, "intel") and hasattr(tl.extra.intel, "libdevice"):
        libdevice = tl.extra.intel.libdevice
        math = tl.math
    else:
        libdevice = tl.math
        math = tl


try:
    from triton.language.standard import _log2
except ImportError:

    def _log2(x):
        raise NotImplementedError


@triton.jit
def promote_to_tensor(x):
    # Addition promotes to tensor for us
    return x + tl.zeros((1,), tl.int1)


@triton.jit
def div_floor_integer(a, b):
    # NOTE: a // b is C division, but we want floor division
    # e.g. C division truncates -7 / 2 to -3, while floor division gives -4
    # Based on c10::div_floor_integer
    quot = a // b
    remainder = a % b
    fixed = tl.where(remainder != 0, quot - 1, quot)
    return tl.where((a < 0) != (b < 0), fixed, quot)


@triton.jit
def remainder_integer(a, b):
    # NOTE: a % b matches C division, not floor division
    remainder = a % b
    return tl.where(remainder != 0 and ((a < 0) != (b < 0)), remainder + b, remainder)


@triton.jit
def is_floating(x):
    return promote_to_tensor(x).dtype.is_floating()


@triton.jit
def _prod_accumulate(a, b):
    return a * b


@triton.jit
def prod(input, axis):
    return tl.reduce(input, axis, _prod_accumulate)


@triton.jit
def minimum(a, b):
    mask = a < b
    if is_floating(a):
        # NaNs propagate: a != a is True only when a is NaN
        mask |= a != a
    return tl.where(mask, a, b)


@triton.jit
def maximum(a, b):
    mask = a > b
    if is_floating(a):
        # NaNs propagate: a != a is True only when a is NaN
        mask |= a != a
    return tl.where(mask, a, b)


@triton.jit
def min2(a, dim):
    return tl.reduce(a, dim, minimum)


@triton.jit
def max2(a, dim):
    return tl.reduce(a, dim, maximum)


@triton.jit
def minimum_with_index(a_value, a_index, b_value, b_index):
    mask = a_value < b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        mask |= a_isnan and not b_isnan
        # Consider NaNs as equal
        equal |= a_isnan and b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def maximum_with_index(a_value, a_index, b_value, b_index):
    mask = a_value > b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        mask |= a_isnan and not b_isnan
        # Consider NaNs as equal
        equal |= a_isnan and b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def min_with_index(value, index, dim):
    return tl.reduce((value, index), dim, minimum_with_index)


@triton.jit
def max_with_index(value, index, dim):
    return tl.reduce((value, index), dim, maximum_with_index)

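# welford_reduce and welford_combine below implement Welford's online algorithm
# for a running mean and m2 (the sum of squared deviations from the mean). As a
# worked reference for the combine step, merging two partial results
# (mean_1, m2_1, weight_1) and (mean_2, m2_2, weight_2) uses:
#
#   delta      = mean_2 - mean_1
#   new_weight = weight_1 + weight_2
#   new_mean   = mean_1 + delta * weight_2 / new_weight
#   new_m2     = m2_1 + m2_2 + delta**2 * weight_1 * weight_2 / new_weight
#
# which is what welford_combine computes, with the weight_2 / new_weight factor
# guarded against division by zero.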
@triton.jit
def welford_reduce(value, mean, m2, weight, first_iteration):
    if first_iteration:
        new_weight = tl.full(weight.shape, 1, weight.dtype)
        new_mean = value
        new_m2 = tl.zeros_like(m2)
    else:
        delta = value - mean
        new_weight = weight + 1
        new_mean = mean + delta / new_weight
        new_m2 = m2 + delta * (value - new_mean)
    return new_mean, new_m2, new_weight


@triton.jit
def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)
    return (
        mean_1 + delta * w2_over_w,
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
        new_weight,
    )


@triton.jit
def welford(mean, m2, weight, dim):
    return tl.reduce((mean, m2, weight), dim, welford_combine)


@triton.jit
def device_assert_then(cond, msg, r):
    tl.device_assert(cond, msg)
    return r


@triton.jit
def randint64(seed, offset, low, high):
    r0, r1, r2, r3 = tl.randint4x(seed, offset)
    r0 = r0.to(tl.uint64)
    r1 = r1.to(tl.uint64)
    result = r0 | (r1 << 32)
    size = high - low
    result = result % size.to(tl.uint64)
    result = result.to(tl.int64) + low
    return result


@triton.jit
def _any_combine(a, b):
    return a | b


@triton.jit
def any(a, dim):
    return tl.reduce(a, dim, _any_combine)


@triton.jit
def bucketize_binary_search(
    values,  # 1D tensor
    offsets_ptr,
    indexing_dtype,
    right,  # bool: if True, use intervals closed on the left; see [Note: Inductor bucketize op]
    OFFSETS_SIZE: int,
    BLOCK_SHAPE,  # tuple/list of block shape
):
    """
    See [Note: Inductor bucketize op]
    """

    low = tl.zeros(BLOCK_SHAPE, dtype=indexing_dtype)
    high = tl.full(BLOCK_SHAPE, OFFSETS_SIZE, dtype=indexing_dtype)

    full_range = OFFSETS_SIZE + 1
    while full_range > 1:
        mid = (high + low) // 2
        mask = mid < OFFSETS_SIZE
        bucket_upper_bound = tl.load(offsets_ptr + mid, mask=mask, other=0.0)
        if right:
            is_above = values >= bucket_upper_bound
        else:
            is_above = values > bucket_upper_bound

        low = tl.where(is_above & mask, mid + 1, low)
        high = tl.where(is_above, high, mid)

        full_range = (full_range + 1) // 2

    return low


@triton.jit
def pack_value_flag(
    value,
    flag,
    DTYPE_VALUE_AS_UINT: tl.constexpr,
    DTYPE_PACK: tl.constexpr,
):
    # Workaround for triton bug, tensor.to doesn't unwrap constexpr values
    DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)
    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
    uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)
    return flag.to(DTYPE_PACK) | (uv << bitwidth)


@triton.jit
def unpack_value(
    pack,
    DTYPE_VALUE,
    DTYPE_VALUE_AS_UINT,
):
    # Workaround for triton bug, tensor.to doesn't unwrap constexpr values
    DTYPE_VALUE = tl.core._constexpr_to_value(DTYPE_VALUE)
    DTYPE_VALUE_AS_UINT = tl.core._constexpr_to_value(DTYPE_VALUE_AS_UINT)
    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
    value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)
    return value_uint.to(DTYPE_VALUE, bitcast=True)


@triton.jit
def unpack_flag(pack, DTYPE_FLAG):
    return pack.to(DTYPE_FLAG)

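# Layout sketch for the (value, flag) packing used by the decoupled look-back
# scan below, assuming for illustration a float32 value packed into a uint64:
#
#   pack_value_flag(value, flag, tl.uint32, tl.uint64)
#       == flag | (bitcast<uint32>(value) << 32)
#
# unpack_value recovers the value from the high bits via the reverse shift and
# bitcast, and unpack_flag reads the flag from the low bits, so a single
# relaxed atomic exchange publishes the flag and the value together.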
@triton.jit
def exclusive_scan_decoupled_lookback(
    scratch_base,
    block_value,
    index,
    combine_fn,
    DTYPE_VALUE_AS_UINT: tl.constexpr,
    DTYPE_PACK: tl.constexpr,
):
    """Compute exclusive scan of a scalar value between blocks

    Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back

    scratch_base: Pointer to scratch space in global memory
    block_value: Scalar value for this block
    index: Scalar index of this block relative to the current scan
    combine_fn: Function ``(value, value) -> value`` which is scanned over
    DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``
    DTYPE_PACK: Unsigned type twice the width of block_value

    NOTE: This function is limited to values which are 32-bits or less because
    we need to pack (value, flag) into a single unsigned int.
    """
    # Publish block sum so subsequent blocks don't get stuck waiting for us
    DTYPE_VALUE = block_value.dtype
    pack = pack_value_flag(
        block_value,
        tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),
        DTYPE_VALUE_AS_UINT,
        DTYPE_PACK,
    )
    if index > 0:
        tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")

    # Calculate exclusive prefix scan
    exclusive_prefix = tl.zeros([], DTYPE_VALUE)
    prefix_valid = False
    test_target = index - 1
    while test_target >= 0:
        # Emulate a relaxed atomic load: spin on atomic_add(ptr, 0) until the
        # flag in the low bits of the packed word becomes nonzero
        flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)
        while flag == 0:
            pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed")
            flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)

        value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)
        if prefix_valid:
            exclusive_prefix = combine_fn(value, exclusive_prefix)
        else:
            exclusive_prefix = value
            prefix_valid = True

        if flag == 2:
            test_target = -1
        else:
            test_target = test_target - 1

    # Make inclusive block sum visible to other blocks
    if prefix_valid:
        inclusive_prefix = combine_fn(exclusive_prefix, block_value)
    else:
        inclusive_prefix = block_value
    pack = pack_value_flag(
        inclusive_prefix,
        tl.full([], 2, DTYPE_VALUE_AS_UINT),
        DTYPE_VALUE_AS_UINT,
        DTYPE_PACK,
    )
    tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")
    return exclusive_prefix

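# The 64-bit variant below cannot pack (value, flag) into a single word, so it
# uses three uint64 scratch slots per block. The layout assumed by its loads
# and stores is:
#
#   scratch_base[3 * i + 0]: flag (0 = not ready, 1 = block sum ready,
#                            2 = inclusive prefix ready)
#   scratch_base[3 * i + 1]: block i's own sum
#   scratch_base[3 * i + 2]: block i's inclusive prefix
#
# Note that the flag value doubles as the offset of the payload word to read.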
@triton.jit
def exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn):
    """Compute exclusive scan of a scalar value between blocks

    Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back

    scratch_base: Pointer to scratch space in global memory
    block_value: Scalar value for this block, must be 64-bits wide
    index: Scalar index of this block relative to the current scan
    combine_fn: Function ``(value, value) -> value`` which is scanned over
    """
    # Publish block sum so subsequent blocks don't get stuck waiting for us
    if index > 0:
        block_value_u64 = block_value.to(tl.uint64, bitcast=True)
        tl.store(scratch_base + 3 * index + 1, block_value_u64)
        tl.debug_barrier()
        flag_one = tl.full([], 1, tl.uint64)
        tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release")

    # Calculate exclusive prefix scan
    exclusive_prefix = tl.zeros([], block_value.dtype)
    prefix_valid = False
    test_target = index - 1
    while test_target >= 0:
        flag = tl.full([], 0, tl.uint64)
        while flag == 0:
            flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire")

        value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))
        value = value_u64.to(block_value.dtype, bitcast=True)
        if prefix_valid:
            exclusive_prefix = combine_fn(value, exclusive_prefix)
        else:
            exclusive_prefix = value
            prefix_valid = True

        if flag == 2:
            test_target = -1
        else:
            test_target = test_target - 1

    # Make inclusive block sum visible to other blocks
    if prefix_valid:
        inclusive_prefix = combine_fn(exclusive_prefix, block_value)
    else:
        inclusive_prefix = block_value
    inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)
    tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)
    tl.debug_barrier()
    flag_two = tl.full([], 2, tl.uint64)
    tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release")

    return exclusive_prefix


@triton.jit
def frexp(x):
    # TODO(isuruf): use inline_asm_elementwise here
    y = libdevice.ilogb(x) + 1
    exponent = tl.where(x == 0, 0, y)
    mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))
    return mantissa, exponent


@triton.jit
def _compare_and_swap_with_index(
    x,
    idxs,
    rnumel,
    flip,
    i: tl.constexpr,
    n_dims: tl.constexpr,
    stable: tl.constexpr,
    descending: tl.constexpr,
):
    n_outer: tl.constexpr = x.numel >> n_dims
    shape: tl.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)

    y = tl.reshape(x, shape)
    iy = y.to(idtype, bitcast=True)
    # slice left/right with 'stride' 2**(n_dims - i - 1)
    right_mask = tl.arange(0, 2)[None, :, None].to(idtype)
    left_mask = (1 - right_mask).to(idtype)
    ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1)[:, None, :], shape)
    iright = tl.broadcast_to(tl.sum(iy * right_mask, 1)[:, None, :], shape)
    ileft = tl.reshape(ileft, x.shape)
    iright = tl.reshape(iright, x.shape)
    left = ileft.to(x.dtype, bitcast=True)
    right = iright.to(x.dtype, bitcast=True)

    # gather the left/right indices of each compared pair
    y_idx = tl.reshape(idxs, shape)
    left_idx = tl.broadcast_to(
        tl.sum(y_idx * left_mask.to(y_idx.dtype), 1)[:, None, :], shape
    )
    right_idx = tl.broadcast_to(
        tl.sum(y_idx * right_mask.to(y_idx.dtype), 1)[:, None, :], shape
    )
    left_idx = tl.reshape(left_idx, x.shape)
    right_idx = tl.reshape(right_idx, x.shape)

    # when rnumel is given, mark out-of-range (padding) elements as invalid
    if rnumel is None:
        left_valid_mask = tl.full(x.shape, True, tl.int1)
        right_valid_mask = tl.full(x.shape, True, tl.int1)
    else:
        left_valid_mask = left_idx < rnumel
        right_valid_mask = right_idx < rnumel

    # actual compare-and-swap
    ix = x.to(idtype, bitcast=True)

    if descending:
        cond = left < right
    else:
        cond = left > right

    if stable:
        # When stable sorting, tie break by index
        cond = cond | ((left == right) & (left_idx > right_idx))

    cond = (right_valid_mask > left_valid_mask) | (
        (right_valid_mask == left_valid_mask) & cond
    )
    cond = cond ^ flip
    ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix))
    new_idxs = idxs ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(idxs))

    return ret.to(x.dtype, bitcast=True), new_idxs

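# The compare-and-swap above exchanges elements branchlessly by XOR-ing the
# bitcast integer view with (left ^ right) wherever the swap condition holds.
# Worked example on one pair: with left = 5 and right = 3, left ^ right = 6,
# so 5 ^ 6 == 3 and 3 ^ 6 == 5, i.e. the two lanes trade values; where the
# condition is false, XOR with 0 leaves the lane unchanged. The index tensor
# is permuted with the same trick so values and indices stay paired.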
@triton.jit
def _bitonic_merge_with_index(
    x,
    idxs,
    rnumel,
    stage: tl.constexpr,
    alternating: tl.constexpr,
    n_dims: tl.constexpr,
    stable: tl.constexpr,
    descending: tl.constexpr,
):
    n_outer: tl.constexpr = x.numel >> n_dims
    tl.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements are re-arranged in ascending order at this stage
    # if flip = 00110011... then the direction alternates between sub-sequences (with a
    # stride of 2) at this stage
    if alternating:
        shape: tl.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = tl.reshape(
            tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        flip = False
    # perform `stage` rounds of `compare-and-swap`
    for i in tl.static_range(stage):
        x, idxs = _compare_and_swap_with_index(
            x, idxs, rnumel, flip, i + (n_dims - stage), n_dims, stable, descending
        )
    return x, idxs


@triton.jit
def sort_with_index(
    x,  # value
    idxs,  # index
    rnumel,  # number of elements
    dim: tl.constexpr = None,
    stable: tl.constexpr = tl.constexpr(False),
    descending: tl.constexpr = tl.constexpr(False),
):
    x, idxs = tl.broadcast(x, idxs)
    # handle default dimension or check that it is the most minor dim
    _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim
    tl.static_assert(
        _dim == len(x.shape) - 1, "only minor dimension is currently supported"
    )
    # iteratively run bitonic merge-sort steps
    n_dims: tl.constexpr = _log2(x.shape[_dim])

    for i in tl.static_range(1, n_dims + 1):
        x, idxs = _bitonic_merge_with_index(
            x,
            idxs,
            rnumel,
            i,
            alternating=i < n_dims,
            n_dims=n_dims,
            stable=stable,
            descending=descending,
        )
    return x, idxs


@triton.jit
def select_one(x, mask, dim, keep_dims=False):
    idtype = tl.core.get_int_dtype(x.dtype.primitive_bitwidth, signed=False)
    ix = x.to(idtype, bitcast=True)
    iy = tl.sum(ix * mask, dim, keep_dims=keep_dims)
    return iy.to(x.dtype, bitcast=True)

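# Usage sketch for select_one, with hypothetical names (`values`, `target_idx`
# and BLOCK are illustrative only and not defined in this module): given a
# one-hot mask along the reduced dimension,
#
#   mask = tl.arange(0, BLOCK)[None, :] == target_idx[:, None]
#   picked = select_one(values, mask, dim=1)
#
# exactly one term of the masked integer sum is nonzero, so the bitcast
# round-trip returns the selected element's bit pattern unchanged.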