import os
import random
import sys


sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from typing import Any, Tuple

from benchmark_runner import BenchmarkRunner  # type: ignore[import-not-found]
from benchmark_utils import (  # type: ignore[import-not-found]
    fits_in_memory,
    get_mm_tensors,
    set_precision,
    transpose_tensors,
)

import torch
from torch._inductor.fx_passes.pad_mm import (  # type: ignore[import-not-found]
    get_alignment_size_dtype,
)
from torch._inductor.utils import fresh_inductor_cache


class BenchmarkRunnerPadMM(BenchmarkRunner):  # type: ignore[misc, no-any-unimported]
    """
    BenchmarkRunner for pad_mm. Used to collect training data with AutoHeuristic to learn a heuristic.
    """

    def __init__(self) -> None:
        super().__init__("pad_mm")

    def create_input(self) -> Tuple[Any, ...]:
        """Sample one random benchmark configuration.

        Returns a tuple of (m, k, n, transpose_left, transpose_right,
        dtype, prepadded_left, prepadded_right) matching the positional
        parameters of run_benchmark.
        """
        dtype = self.get_dtype()
        # set_precision presumably configures matmul precision for this
        # dtype (defined in benchmark_utils, not visible here).
        set_precision(dtype)
        m, k, n = self.get_m_k_n(dtype)

        (transpose_left, transpose_right) = transpose_tensors()
        prepadded_left = self.prepadded()
        prepadded_right = self.prepadded()
        return (
            m,
            k,
            n,
            transpose_left,
            transpose_right,
            dtype,
            prepadded_left,
            prepadded_right,
        )

    def run_benchmark(
        self,
        m: int,
        k: int,
        n: int,
        transpose_left: bool,
        transpose_right: bool,
        dtype: Any,
        prepadded_left: bool,
        prepadded_right: bool,
    ) -> None:
        """Compile and run one (m, k) x (k, n) matmul under inductor.

        A "prepadded" operand is simulated by adding a pointwise op
        (``+ 1``) before the matmul, so that pad_mm's benchmarking
        excludes the time it takes to produce (pad) that operand.
        Runs inside a fresh inductor cache and resets the compiler
        afterwards so each sample is compiled from scratch.
        """
        a, b = get_mm_tensors(
            m,
            k,
            n,
            transpose_left,
            transpose_right,
            dtype_left=dtype,
            dtype_right=dtype,
        )

        print("Benchmarking the following input:")
        print(f"m={m} k={k} n={n} dtype={dtype}")
        print(f"transpose_left={transpose_left} transpose_right={transpose_right}")
        print(f"prepadded_left={prepadded_left} prepadded_right={prepadded_right}")

        with fresh_inductor_cache():

            def mm(a: Any, b: Any) -> Any:
                return torch.mm(a, b)

            def mm_mat1_prepadded(a: Any, b: Any) -> Any:
                return torch.mm(a + 1, b)

            def mm_mat2_prepadded(a: Any, b: Any) -> Any:
                return torch.mm(a, b + 1)

            def mm_mat1_mat2_prepadded(a: Any, b: Any) -> Any:
                return torch.mm(a + 1, b + 1)

            if prepadded_left and prepadded_right:
                cf = torch.compile(mm_mat1_mat2_prepadded)
            elif prepadded_left:
                cf = torch.compile(mm_mat1_prepadded)
            elif prepadded_right:
                cf = torch.compile(mm_mat2_prepadded)
            else:
                cf = torch.compile(mm)
            cf(a, b)
            torch.compiler.reset()

    def get_random_dim(
        self, min_power2: int = 1, max_power2: int = 16, p_unaligned: float = 0.25
    ) -> int:
        """Sample a random matrix dimension.

        With probability ``1 - p_unaligned`` returns an exact power of
        two in [2**min_power2, 2**max_power2]; otherwise delegates to
        ``get_random_between_pow2`` (inherited from BenchmarkRunner,
        not visible here — presumably a value strictly between two
        adjacent powers of two).
        """
        aligned = random.choices([True, False], [1 - p_unaligned, p_unaligned])[0]
        if aligned:
            return 2 ** random.randint(min_power2, max_power2)  # type: ignore[no-any-return]
        else:
            # choose a random number between 2^i and 2^(i+1)
            return self.get_random_between_pow2(min_power2, max_power2)  # type: ignore[no-any-return]

    def is_aligned(self, dim: int, align_size: int) -> bool:
        """Return True if dim is a multiple of align_size."""
        return dim % align_size == 0

    def get_m_k_n(self, dtype: Any) -> Tuple[int, int, int]:
        """Sample (m, k, n) for a matmul of the given dtype.

        Half the time all three dims are drawn uniformly from
        [1, 65536]; otherwise each comes from get_random_dim (biased
        toward powers of two). Configurations where every dim is
        already aligned are rejected — pad_mm would be a no-op there —
        as are configurations whose tensors do not fit in memory
        (per fits_in_memory from benchmark_utils).
        """
        # random.choice: unweighted single pick (was choices(...)[0])
        uniform = random.choice([True, False])
        align_size = get_alignment_size_dtype(dtype)

        # repeat until tensors fit in memory
        while True:
            if uniform:
                m = random.randint(1, 65536)
                k = random.randint(1, 65536)
                n = random.randint(1, 65536)
            else:
                m = self.get_random_dim()
                k = self.get_random_dim()
                n = self.get_random_dim()

            if all(self.is_aligned(dim, align_size) for dim in [m, k, n]):
                # skip if already aligned
                continue

            if fits_in_memory(dtype, m, k, n):
                return (m, k, n)

    def prepadded(self, p_prepadded: float = 0.2) -> bool:
        # p_prepadded: probability that a tensor is "prepadded", i.e.
        # pad_mm excludes time it takes to pad from benchmarking
        return random.choices([True, False], [p_prepadded, 1 - p_prepadded])[0]

    def get_dtype(self) -> Any:
        """Pick a dtype uniformly from float16/bfloat16/float32."""
        dtype_choices = [torch.float16, torch.bfloat16, torch.float32]
        return random.choice(dtype_choices)


if __name__ == "__main__":
    runner = BenchmarkRunnerPadMM()
    runner.run()