# xref: /aosp_15_r20/external/pytorch/torchgen/_autoheuristic/pad_mm/gen_data_pad_mm.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import os
2import random
3import sys
4
5
6sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7
8from typing import Any, Tuple
9
10from benchmark_runner import BenchmarkRunner  # type: ignore[import-not-found]
11from benchmark_utils import (  # type: ignore[import-not-found]
12    fits_in_memory,
13    get_mm_tensors,
14    set_precision,
15    transpose_tensors,
16)
17
18import torch
19from torch._inductor.fx_passes.pad_mm import (  # type: ignore[import-not-found]
20    get_alignment_size_dtype,
21)
22from torch._inductor.utils import fresh_inductor_cache
23
24
class BenchmarkRunnerPadMM(BenchmarkRunner):  # type: ignore[misc, no-any-unimported]
    """
    BenchmarkRunner for pad_mm. Used to collect training data with AutoHeuristic
    to learn a heuristic for when padding matmul inputs pays off.
    """

    def __init__(self) -> None:
        super().__init__("pad_mm")

    def create_input(self) -> Tuple[Any, ...]:
        """Sample one random benchmark configuration.

        Returns a tuple of
        (m, k, n, transpose_left, transpose_right, dtype, prepadded_left, prepadded_right)
        matching the parameter order of run_benchmark.
        """
        dtype = self.get_dtype()
        set_precision(dtype)
        m, k, n = self.get_m_k_n(dtype)

        (transpose_left, transpose_right) = transpose_tensors()
        prepadded_left = self.prepadded()
        prepadded_right = self.prepadded()
        return (
            m,
            k,
            n,
            transpose_left,
            transpose_right,
            dtype,
            prepadded_left,
            prepadded_right,
        )

    def run_benchmark(
        self,
        m: int,
        k: int,
        n: int,
        transpose_left: bool,
        transpose_right: bool,
        dtype: Any,
        prepadded_left: bool,
        prepadded_right: bool,
    ) -> None:
        """Compile and run a single (m, k) x (k, n) matmul under torch.compile.

        A "prepadded" operand is simulated by adding a cheap pointwise op
        (`+ 1`) before the mm, so pad_mm's benchmarking excludes the time it
        takes to pad that operand. Uses a fresh inductor cache per run and
        resets the compiler afterwards so runs do not influence each other.
        """
        a, b = get_mm_tensors(
            m,
            k,
            n,
            transpose_left,
            transpose_right,
            dtype_left=dtype,
            dtype_right=dtype,
        )

        print("Benchmarking the following input:")
        print(f"m={m} k={k} n={n} dtype={dtype}")
        print(f"transpose_left={transpose_left} transpose_right={transpose_right}")
        print(f"prepadded_left={prepadded_left} prepadded_right={prepadded_right}")

        with fresh_inductor_cache():

            def mm(a: Any, b: Any) -> Any:
                return torch.mm(a, b)

            def mm_mat1_prepadded(a: Any, b: Any) -> Any:
                return torch.mm(a + 1, b)

            def mm_mat2_prepadded(a: Any, b: Any) -> Any:
                return torch.mm(a, b + 1)

            def mm_mat1_mat2_prepadded(a: Any, b: Any) -> Any:
                return torch.mm(a + 1, b + 1)

            if prepadded_left and prepadded_right:
                cf = torch.compile(mm_mat1_mat2_prepadded)
            elif prepadded_left:
                cf = torch.compile(mm_mat1_prepadded)
            elif prepadded_right:
                cf = torch.compile(mm_mat2_prepadded)
            else:
                cf = torch.compile(mm)
            cf(a, b)
            torch.compiler.reset()

    def get_random_dim(
        self, min_power2: int = 1, max_power2: int = 16, p_unaligned: float = 0.25
    ) -> int:
        """Sample a random dimension size.

        With probability 1 - p_unaligned the result is an exact power of two in
        [2**min_power2, 2**max_power2]; otherwise a value strictly between two
        consecutive powers of two in that range.
        """
        aligned = random.random() >= p_unaligned
        if aligned:
            return 2 ** random.randint(min_power2, max_power2)  # type: ignore[no-any-return]
        else:
            # choose a random number between 2^i and 2^(i+1)
            return self.get_random_between_pow2(min_power2, max_power2)  # type: ignore[no-any-return]

    def is_aligned(self, dim: int, align_size: int) -> bool:
        """Return True if `dim` is a multiple of `align_size`."""
        return dim % align_size == 0

    def get_m_k_n(self, dtype: Any) -> Tuple[int, int, int]:
        """Sample matmul shapes (m, k, n) for the given dtype.

        Half the time all three dims are drawn uniformly from [1, 65536];
        otherwise each comes from get_random_dim (power-of-two biased).
        Rejects shapes that are already fully aligned (padding would be a
        no-op, so they carry no training signal) and shapes whose tensors
        would not fit in memory.
        """
        uniform = random.choice([True, False])
        align_size = get_alignment_size_dtype(dtype)

        # repeat until tensors fit in memory
        while True:
            if uniform:
                m = random.randint(1, 65536)
                k = random.randint(1, 65536)
                n = random.randint(1, 65536)
            else:
                m = self.get_random_dim()
                k = self.get_random_dim()
                n = self.get_random_dim()

            if all(self.is_aligned(dim, align_size) for dim in [m, k, n]):
                # skip if already aligned
                continue

            if fits_in_memory(dtype, m, k, n):
                return (m, k, n)

    def prepadded(self, p_prepadded: float = 0.2) -> bool:
        # p_prepadded: probability that a tensor is "prepadded", i.e. pad_mm excludes time it takes to pad from benchmarking
        return random.random() < p_prepadded

    def get_dtype(self) -> Any:
        """Pick one of float16 / bfloat16 / float32 uniformly at random."""
        return random.choice([torch.float16, torch.bfloat16, torch.float32])
145
146
if __name__ == "__main__":
    # Entry point: construct the pad_mm runner and start collecting data.
    BenchmarkRunnerPadMM().run()
150