import functools
from typing import Any, Callable, Dict, List, Tuple

import torch


Feedback = float
Choice = str
Value = Any

CHOICE_COL = "choice"
FEEDBACK_COL = "feedback"


class AHFeature:
    """
    The context that AutoHeuristic stores is a list of features. AutoHeuristic needs to know whether a feature is
    categorical (i.e., not a continuous variable) to learn a machine learning model.
    """

    def __init__(self, name: str, value: Value, is_categorical: bool = False) -> None:
        self.name = name
        self.value = value
        self.is_categorical = is_categorical


class AHOperation:
    """
    AHOperation can be used to augment the data collected by AutoHeuristic.
    One might for example store features like m, k, n, but also want to use
    features like m*n, or k*n, to learn a heuristic. Instead of storing features
    that can be created from the collected data, one can use AHOperation to
    create new features from the collected data.

    ``func`` receives the whole data mapping; its result is stored under
    ``name`` by :meth:`apply_operation`.
    """

    def __init__(
        self, name: str, func: Callable[[Any], Value], is_categorical: bool = False
    ) -> None:
        self.name = name
        self.func = func
        self.is_categorical = is_categorical

    def apply_operation(self, data: Any) -> None:
        """Evaluate ``func(data)`` and write the result into ``data[name]`` in place."""
        data[self.name] = self.func(data)


class AHContext:
    """
    This class is used to specify which information AutoHeuristic should store. For each choice, AutoHeuristic will
    store the context and the collected feedback. The context could be something like the shape of a tensor, i.e.,
    information that will help to learn a heuristic.
    """

    features: List[AHFeature]
    context_dict: Dict[str, Value]

    def __init__(self) -> None:
        self.features = []
        self.context_dict = {}

    def add_feature(
        self, name: str, value: Value, is_categorical: bool = False
    ) -> None:
        """Record a feature in the ordered feature list and in the name -> value lookup."""
        self.features.append(AHFeature(name, value, is_categorical=is_categorical))
        self.context_dict[name] = value

    def get_numerical_and_categorical_features(self) -> Tuple[List[str], List[str]]:
        """Return ``(numerical_names, categorical_names)``, each in insertion order."""
        numerical_features = [f.name for f in self.features if not f.is_categorical]
        categorical_features = [f.name for f in self.features if f.is_categorical]
        return numerical_features, categorical_features

    def get_feature_names_csv(self) -> str:
        """Return the feature names joined by commas, in insertion order."""
        return ",".join(feature.name for feature in self.features)

    def get_feature_values_csv(self) -> str:
        """Return ``str(value)`` of each feature joined by commas, in insertion order.

        NOTE(review): values are not quoted/escaped, so a value whose string
        form contains a comma would shift the columns of the row.
        """
        return ",".join(str(feature.value) for feature in self.features)

    def get_value(self, name: str) -> Value:
        """Look up a feature (or operation result) by name; raises KeyError if absent."""
        return self.context_dict[name]

    def apply_operations(self, operations: List[AHOperation]) -> None:
        """Apply each operation, in order, to ``context_dict``.

        Results are written only to ``context_dict``; they are not appended to
        ``features`` and therefore do not appear in the CSV helpers.
        Later operations can read the results of earlier ones.
        """
        for op in operations:
            op.apply_operation(self.context_dict)


class AHMetadata:
    def __init__(
        self,
        shared_memory: Any,
        device_capa: Tuple[int, int],
        choices: List[Choice],
        name: str,
    ) -> None:
        # use amount of shared_memory and device_capability to identify GPU
        # TODO(AlnisM): there might be a better way to do this
        self.shared_memory = shared_memory
        self.device_capa = device_capa
        self.choices = choices
        self.name = name

    def to_dict(self) -> Dict[str, Value]:
        """Return shared_memory, device_capa and name as a dict (``choices`` is not included)."""
        return {
            "shared_memory": self.shared_memory,
            "device_capa": self.device_capa,
            "name": self.name,
        }


def get_metadata_str_from_log(log_path: str) -> str:
    """Return the first line of ``log_path`` with surrounding whitespace stripped.

    The line is presumably the metadata serialized as JSON (it is not parsed or
    validated here).
    """
    with open(log_path, newline="") as file:
        json_string = file.readline().strip()
    return json_string


def check_minsize(context: AHContext, minsize: int) -> bool:
    """Return True iff each of the ``m``, ``k`` and ``n`` features is >= ``minsize``."""
    return (
        context.get_value("m") >= minsize
        and context.get_value("k") >= minsize
        and context.get_value("n") >= minsize
    )


def pad_mm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
    """Return whether the pad_mm heuristic should be consulted for this GPU/context.

    Known GPUs (identified by shared memory size and device capability) require
    a minimum size for m, k and n; any other GPU always passes.
    """
    if metadata.shared_memory == 166912 and metadata.device_capa == (8, 0):
        # A100 precondition
        return check_minsize(context, 512)
    elif metadata.shared_memory == 232448 and metadata.device_capa == (9, 0):
        # H100 precondition
        return check_minsize(context, 768)
    return True


def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
    """Return whether the mixed_mm heuristic applies.

    The result is False unless m <= 128, k >= 1024 and n >= 1024. Otherwise it
    is True only when mat1 is contiguous and mat2 is not.
    """
    m = context.get_value("m")
    k = context.get_value("k")
    n = context.get_value("n")
    if m > 128 or k < 1024 or n < 1024:
        return False
    mat1_iscontig = context.get_value("mat1_iscontig")
    mat2_iscontig = context.get_value("mat2_iscontig")
    return mat1_iscontig and not mat2_iscontig


def get_mult_dims_ops() -> List[AHOperation]:
    """Operations producing the pairwise products ``m*k``, ``m*n`` and ``k*n``."""
    m_times_k_op = AHOperation("m*k", lambda data: data["m"] * data["k"])
    m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"])
    k_times_n_op = AHOperation("k*n", lambda data: data["k"] * data["n"])
    return [m_times_k_op, m_times_n_op, k_times_n_op]


def get_arith_intensity(data: Any) -> float:
    """Return ``m*k*n / (m*k + k*n + m*n)``, or 0.0 if any of m, k, n is zero."""
    m = data["m"]
    k = data["k"]
    n = data["n"]
    if m == 0 or k == 0 or n == 0:
        return 0.0
    return m * k * n / (m * k + k * n + m * n)


def pad_mm_operations() -> List[AHOperation]:
    """Return all derived-feature operations used by the pad_mm heuristic."""
    mult_dims_ops = get_mult_dims_ops()

    def k_div_m_times_n(data: Any) -> float:
        m_times_n = data["m"] * data["n"]
        # Avoid ZeroDivisionError for zero-sized dims; mirrors get_arith_intensity,
        # which also reports 0.0 when a dimension is zero.
        if m_times_n == 0:
            return 0.0
        return data["k"] / m_times_n

    k_div_m_times_n_op = AHOperation("k/(m*n)", k_div_m_times_n)

    def bfloat_perf_hit(data: Any) -> bool:
        # True when k dwarfs both m and n (by more than 1024x) and mat1 is bfloat16.
        m = data["m"]
        k = data["k"]
        n = data["n"]
        is_bfloat = str(data["mat1_dtype"]) == "torch.bfloat16"
        return k > (m * 1024) and k > (n * 1024) and is_bfloat

    bfloat_perf_hit_op = AHOperation(
        "bfloat_perf_hit", bfloat_perf_hit, is_categorical=True
    )

    arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
    dims_need_padding_ops = get_dims_need_padding_ops()
    dims_multiple_ops = get_dims_multiple_ops()
    is_contig_ops = get_is_contig_ops()

    ah_operations = mult_dims_ops + [
        k_div_m_times_n_op,
        bfloat_perf_hit_op,
        arith_intensity_op,
    ]
    ah_operations.extend(dims_need_padding_ops)
    ah_operations.extend(dims_multiple_ops)
    ah_operations.extend(is_contig_ops)
    return ah_operations


def between_op(data: Any, dim: str, lower: int, upper: int) -> bool:
    """Return True iff ``lower <= data[dim] <= upper`` (both bounds inclusive)."""
    return lower <= data[dim] <= upper


def between_ops() -> List[AHOperation]:
    """Categorical range-bucket operations for each of m, k, n."""
    dims = ["m", "k", "n"]
    limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)]
    ah_operations = []
    for dim in dims:
        for lower, upper in limits:
            between_op_fn = functools.partial(
                between_op, dim=dim, lower=lower, upper=upper
            )
            # using 'LEQ' instead of '<=' because '<=' cannot be exported to dot
            between_op_name = f"{lower}LEQ{dim}LEQ{upper}"
            ah_operations.append(
                AHOperation(between_op_name, between_op_fn, is_categorical=True)
            )
    return ah_operations


def pow2_op(data: Any, dim: str, exponent: int) -> bool:
    """Return True iff ``data[dim]`` equals ``2**exponent``."""
    return data[dim] == 2**exponent


def mm_operations() -> List[AHOperation]:
    """Derived-feature operations for mm: pairwise products plus arithmetic intensity."""
    mult_dims_ops = get_mult_dims_ops()
    arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
    return mult_dims_ops + [arith_intensity_op]


def mixed_mm_operations() -> List[AHOperation]:
    """Derived-feature operations for mixed_mm: mm operations plus range buckets."""
    return mm_operations() + between_ops()


def is_multiple(data: Any, dim: str, mult: int) -> bool:
    """Return True iff ``data[dim]`` is divisible by ``mult``."""
    return data[dim] % mult == 0


def get_dims_multiple_ops() -> List[AHOperation]:
    """Categorical ``{dim}_multiple_{mult}`` operations for m, k, n and mult in 2..32."""
    multiples = [2, 4, 8, 16, 32]
    dims = ["m", "k", "n"]
    dims_multiple_ops = []
    for dim in dims:
        for mult in multiples:
            is_multiple_fn = functools.partial(is_multiple, dim=dim, mult=mult)
            dims_multiple_op = AHOperation(
                f"{dim}_multiple_{mult}", is_multiple_fn, is_categorical=True
            )
            dims_multiple_ops.append(dims_multiple_op)
    return dims_multiple_ops


def get_dims_need_padding_ops() -> List[AHOperation]:
    """Operations describing which dimensions would need padding.

    Reads ``mat{1,2}_stride_{0,1}`` and ``{m,k,n}_padded_length`` from the data;
    a non-zero padded length is treated as "this dim needs padding".
    """

    def mat1_innermost_needs_padding_fn(data: Any) -> bool:
        # mat1 is (m, k): the stride-1 dimension is the innermost one.
        mat1_stride_0 = data["mat1_stride_0"]
        mat1_stride_1 = data["mat1_stride_1"]
        m_padded_length = data["m_padded_length"]
        k_padded_length = data["k_padded_length"]
        mat1_innermost_needs_padding = False
        if mat1_stride_0 == 1 and m_padded_length != 0:
            mat1_innermost_needs_padding = True
        if mat1_stride_1 == 1 and k_padded_length != 0:
            mat1_innermost_needs_padding = True
        return mat1_innermost_needs_padding

    mat1_innermost_op = AHOperation(
        "mat1_innermost_needs_padding",
        mat1_innermost_needs_padding_fn,
        is_categorical=True,
    )

    def mat2_innermost_needs_padding_fn(data: Any) -> bool:
        # mat2 is (k, n): the stride-1 dimension is the innermost one.
        mat2_stride_0 = data["mat2_stride_0"]
        mat2_stride_1 = data["mat2_stride_1"]
        k_padded_length = data["k_padded_length"]
        n_padded_length = data["n_padded_length"]
        mat2_innermost_needs_padding = False
        if mat2_stride_0 == 1 and k_padded_length != 0:
            mat2_innermost_needs_padding = True
        if mat2_stride_1 == 1 and n_padded_length != 0:
            mat2_innermost_needs_padding = True
        return mat2_innermost_needs_padding

    mat2_innermost_op = AHOperation(
        "mat2_innermost_needs_padding",
        mat2_innermost_needs_padding_fn,
        is_categorical=True,
    )

    def num_dims_needs_padding_fn(data: Any) -> int:
        # Count of m, k, n whose padded length is non-zero (0..3).
        padded_lengths = [data[f"{dim}_padded_length"] for dim in ("m", "k", "n")]
        return sum(1 for length in padded_lengths if length != 0)

    num_dims_op = AHOperation("num_dims_needs_padding", num_dims_needs_padding_fn)
    return [mat1_innermost_op, mat2_innermost_op, num_dims_op]


def get_is_contig_ops() -> List[AHOperation]:
    """Categorical row-major contiguity checks for mat1 (m, k) and mat2 (k, n)."""

    def mat1_is_contig_fn(data: Any) -> bool:
        stride_0 = data["mat1_stride_0"]
        stride_1 = data["mat1_stride_1"]
        k = data["k"]
        return stride_0 == k and stride_1 == 1

    mat1_is_contig_op = AHOperation(
        "mat1_iscontig", mat1_is_contig_fn, is_categorical=True
    )

    def mat2_is_contig_fn(data: Any) -> bool:
        stride_0 = data["mat2_stride_0"]
        stride_1 = data["mat2_stride_1"]
        n = data["n"]
        return stride_0 == n and stride_1 == 1

    mat2_is_contig_op = AHOperation(
        "mat2_iscontig", mat2_is_contig_fn, is_categorical=True
    )

    return [mat1_is_contig_op, mat2_is_contig_op]


def context_add_strides(context: AHContext, name: str, stride: Tuple[int, ...]) -> None:
    """Add one ``{name}_stride_{i}`` feature per entry of ``stride``."""
    for i, s in enumerate(stride):
        context.add_feature(f"{name}_stride_{i}", s)


def context_add_using_tf32(context: AHContext, dtype: torch.dtype) -> None:
    """Add the categorical ``using_tf32`` feature.

    For float32 the value is ``torch.backends.cuda.matmul.allow_tf32``; for any
    other dtype it is the string ``"not_float_32"``.
    """
    using_tf32 = "not_float_32"
    if dtype == torch.float32:
        using_tf32 = torch.backends.cuda.matmul.allow_tf32
    context.add_feature("using_tf32", using_tf32, is_categorical=True)