xref: /aosp_15_r20/external/pytorch/torchgen/_autoheuristic/benchmark_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import random
2from typing import Any, Tuple
3
4import torch
5
6
def transpose_tensors(p_transpose_both: float = 0.05) -> Tuple[bool, bool]:
    """Randomly decide which matmul operands should be transposed.

    Args:
        p_transpose_both: Probability of transposing both operands.

    Returns:
        A ``(transpose_left, transpose_right)`` pair of flags. With
        probability ``p_transpose_both`` this is ``(True, True)``; otherwise
        it is drawn uniformly from ``(True, False)``, ``(False, True)`` and
        ``(False, False)``.
    """
    both_weights = [p_transpose_both, 1 - p_transpose_both]
    if random.choices([True, False], both_weights)[0]:
        return (True, True)

    # Not transposing both: pick one of the remaining layouts uniformly.
    remaining_layouts = [(True, False), (False, True), (False, False)]
    return random.choices(remaining_layouts)[0]
17
18
def fits_in_memory(dtype: Any, m: int, k: int, n: int) -> Any:
    """Check whether a matmul problem of the given size is small enough to run.

    The footprint is estimated as ``dtype.itemsize`` times the combined
    element count of an ``m x k``, a ``k x n`` and an ``m x n`` matrix, and
    is compared against a quarter of the total memory of CUDA device 0.

    Args:
        dtype: Element type; must expose an ``itemsize`` attribute.
        m: First matrix dimension.
        k: Shared (inner) matrix dimension.
        n: Last matrix dimension.

    Returns:
        True if the estimated footprint is strictly below the threshold.
    """
    # Only a quarter of the device memory is used as the budget because we
    # otherwise sometimes run out of memory, presumably because inductor
    # creates copies of tensors for benchmarking.
    memory_budget = torch.cuda.get_device_properties(0).total_memory / 4
    element_count = m * k + k * n + m * n
    return dtype.itemsize * element_count < memory_budget
24
25
def get_mm_tensors(
    m: int,
    k: int,
    n: int,
    transpose_left: bool,
    transpose_right: bool,
    dtype_left: Any,
    dtype_right: Any,
) -> Tuple[Any, Any]:
    """Create a pair of random operands for an ``(m, k) @ (k, n)`` matmul.

    An operand flagged for transposition is allocated with its two
    dimensions swapped and then returned as a ``.t()`` view, so its logical
    shape is the same as in the non-transposed case.

    Args:
        m: Number of rows of the left operand.
        k: Shared inner dimension.
        n: Number of columns of the right operand.
        transpose_left: Whether the left operand is a transposed view.
        transpose_right: Whether the right operand is a transposed view.
        dtype_left: dtype passed to ``torch.randn`` for the left operand.
        dtype_right: dtype passed to ``torch.randn`` for the right operand.

    Returns:
        ``(a, b)`` where ``a`` has shape ``(m, k)`` and ``b`` has shape
        ``(k, n)``.
    """
    # Left operand: allocate swapped when it is to be handed out transposed.
    left_alloc_shape = (k, m) if transpose_left else (m, k)
    a = torch.randn(*left_alloc_shape, dtype=dtype_left)
    if transpose_left:
        a = a.t()

    # Right operand: same scheme.
    right_alloc_shape = (n, k) if transpose_right else (k, n)
    b = torch.randn(*right_alloc_shape, dtype=dtype_right)
    if transpose_right:
        b = b.t()

    return (a, b)
45
46
def set_precision(dtype: Any, p_float32_prec_highest: float = 0.8) -> None:
    """Choose and apply a float32 matmul precision setting for ``dtype``.

    For ``torch.float32`` the precision is drawn at random: ``"highest"``
    with probability ``p_float32_prec_highest`` and ``"high"`` otherwise.
    For any other dtype ``"high"`` is always used. The chosen value is
    passed to ``torch.set_float32_matmul_precision``.

    Args:
        dtype: dtype of the matmul being benchmarked.
        p_float32_prec_highest: Probability of selecting ``"highest"`` when
            ``dtype`` is ``torch.float32``.
    """
    chosen = "high"
    if dtype == torch.float32:
        p_high = 1 - p_float32_prec_highest
        chosen = random.choices(
            ["high", "highest"], [p_high, p_float32_prec_highest]
        )[0]
    torch.set_float32_matmul_precision(chosen)
55
56
def get_random_between_pow2(min_power2: int, max_power2: int) -> int:
    """Return a random integer strictly between two consecutive powers of two.

    An exponent ``e`` is drawn uniformly from
    ``[min_power2, max_power2 - 1]``, then the result is drawn uniformly
    from ``[2**e + 1, 2**(e + 1) - 1]``, so the result is never itself a
    power of two.

    Note that for ``e == 0`` that interval is empty (there is no integer
    strictly between 1 and 2), in which case the assertion below fails.

    Args:
        min_power2: Smallest exponent that may be drawn (inclusive).
        max_power2: Exponent of the exclusive upper power-of-two bound.

    Returns:
        An integer ``x`` with ``2**e < x < 2**(e + 1)`` for the drawn ``e``.
    """
    exponent = random.randint(min_power2, max_power2 - 1)
    lower, upper = 2**exponent + 1, 2 ** (exponent + 1) - 1
    assert lower <= upper, "lower must not be greater than upper"
    return random.randint(lower, upper)
63