# mypy: allow-untyped-defs
import copy
import itertools
import logging
from typing import Callable, Optional

from .hints import TRITON_MAX_BLOCK
from .runtime_utils import red_text, triton_config_to_hashable


try:
    import triton
except ImportError:
    triton = None

log = logging.getLogger(__name__)


def get_field(config, name):
    """
    Read the tunable field ``name`` from ``config``.

    ``num_warps`` and ``num_stages`` are attributes of the config; every
    other field is looked up in ``config.kwargs`` and yields None when it
    is absent.
    """
    if name in ("num_warps", "num_stages"):
        return getattr(config, name)
    return config.kwargs.get(name, None)


def set_field(config, name, value):
    """
    Write ``value`` into the tunable field ``name`` of ``config``.

    ``num_warps`` and ``num_stages`` are attributes of the config; every
    other field is stored in ``config.kwargs``.
    """
    if name in ("num_warps", "num_stages"):
        setattr(config, name, value)
        return
    config.kwargs[name] = value


class CoordescTuner:
    """
    The coordinate descent tuner. Tunes one field/coordinate at a time.

    TODO: will it be necessary to tune multiple fields simultaneously?

    TODO: what if both increasing and decreasing a field can improve perf,
          i.e., there are multiple local optima?
    """

    def __init__(
        self, is_mm=False, name="unknown", size_hints=None, inductor_meta=None
    ):
        self.name = name
        self.is_mm = is_mm  # we will tune num_stages for mm
        self.size_hints = size_hints
        self.inductor_meta = inductor_meta or {}
        # Maps a hashable form of a config to its measured timing.
        self.cached_benchmark_results = {}

    def prefix_to_size_hint(self, prefix: str) -> Optional[int]:
        """
        Return the size hint for the dimension named by ``prefix``
        ("X", "Y", "Z" or "R"), or None when no hint is available.

        "R" maps to the last entry of ``self.size_hints``.
        """
        hint_idx = {"X": 0, "Y": 1, "Z": 2, "R": -1}[prefix]
        hints = self.size_hints

        if hints is None or len(hints) <= 0 or len(hints) <= hint_idx:
            return None
        return hints[hint_idx]

    def get_config_max(self, prefix: str) -> int:
        """
        Return the largest block size allowed for the dimension ``prefix``:
        TRITON_MAX_BLOCK[prefix], further capped by the size hint if any.
        """
        block_limit = TRITON_MAX_BLOCK[prefix]
        hint = self.prefix_to_size_hint(prefix)
        if hint is None:
            return block_limit
        return min(block_limit, hint)

    def get_warpsmax(self):
        """Return the maximum value considered for num_warps."""
        # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
        # number of warps.
        return 1024 // 32

    def cache_benchmark_result(self, config, timing):
        """Remember ``timing`` as the benchmark result of ``config``."""
        key = triton_config_to_hashable(config)
        self.cached_benchmark_results[key] = timing

    def lookup_in_cache(self, config):
        """Return the remembered timing of ``config``, or None if unknown."""
        key = triton_config_to_hashable(config)
        return self.cached_benchmark_results.get(key)

    def call_func(self, func, config):
        """
        Return the timing of ``config``, calling ``func(config)`` only when
        no timing has been cached for it yet.
        """
        cached = self.lookup_in_cache(config)
        if cached is None:
            measured = func(config)
            self.cache_benchmark_result(config, measured)
            return measured

        log.debug(" CACHED")
        return cached

    @property
    def tunable_fields(self):
        """List of field names the tuner will try to change."""
        fields = [
            "XBLOCK",
            "YBLOCK",
            "ZBLOCK",
            # NOTE: we should not tune RBLOCK for persistent reduction.
            # We rely on the fact that persistent reduction's triton.Config
            # does not have the RBLOCK field to guarantee that.
            "RBLOCK",
            # the following 3 are for mm
            "BLOCK_M",
            "BLOCK_N",
            "BLOCK_K",
            "num_warps",
        ]
        if self.is_mm:
            fields.append("num_stages")

        return fields

    def value_too_large(self, name: str, val: int) -> bool:
        """
        Tell whether ``val`` exceeds the upper bound of field ``name``.

        Only the X/Y/Z/R block sizes and num_warps have an upper bound;
        any other field is never considered too large.
        """
        if name == "num_warps":
            return val > self.get_warpsmax()
        if name in {"XBLOCK", "YBLOCK", "ZBLOCK", "RBLOCK"}:
            # The first letter of the field name is the dimension prefix.
            return val > self.get_config_max(name[0])

        return False

    def get_neighbour_values(self, name, orig_val, radius=1, include_self=False):
        """
        Get the neighbour values of ``orig_val`` within 'radius' steps.

        A step adds/subtracts 1 for num_stages and doubles/halves the value
        for every other field. The result lists the larger values first
        (stopping at the first one that is too large), then the smaller
        ones (stopping at the first one that is not positive). The original
        value is appended at the end only when ``include_self`` is true.
        """
        assert radius >= 1

        def step(value, up):
            if name == "num_stages":
                return value + 1 if up else value - 1
            return value * 2 if up else value // 2

        # Walk upwards.
        larger = []
        value = orig_val
        for _ in range(radius):
            value = step(value, True)
            if self.value_too_large(name, value):
                break
            larger.append(value)

        # Walk downwards.
        smaller = []
        value = orig_val
        for _ in range(radius):
            value = step(value, False)
            if value <= 0:
                break
            smaller.append(value)

        neighbours = larger + smaller
        if include_self:
            neighbours.append(orig_val)
        return neighbours

    @staticmethod
    def has_improvement(baseline, test):
        """
        Tell whether timing ``test`` beats ``baseline`` by more than 0.1%.
        A ``test`` of None never counts as an improvement.
        """
        if test is None:
            return False
        threshold = 0.001  # 0.1%
        return test < baseline * (1 - threshold)

    def check_all_tuning_directions(
        self,
        func: Callable[["triton.Config"], float],
        best_config,
        best_timing,
    ):
        """
        Check all directions: try every combination of neighbour values
        (including the current value) over all fields present in
        ``best_config``. We only do this once the regular coordinate
        descent tuning finds no better choice any more.
        We only have a few tunable fields, so this should be fine.

        Returns a tuple (improved, best_config, best_timing).
        """
        radius = self.inductor_meta.get("coordinate_descent_search_radius", 1)

        per_field = []
        for field in self.tunable_fields:
            current = get_field(best_config, field)
            if current is None:
                # This config does not have the field at all.
                continue
            neighbours = self.get_neighbour_values(
                field,
                current,
                radius=radius,
                include_self=True,
            )
            per_field.append((field, neighbours))

        effective_fields = [field for field, _ in per_field]
        candidate_values_list = [values for _, values in per_field]

        improved = False
        for choice in itertools.product(*candidate_values_list):
            assert len(choice) == len(effective_fields)
            candidate_config = copy.deepcopy(best_config)
            for field, new_val in zip(effective_fields, choice):
                set_field(candidate_config, field, new_val)

            cmp_res, candidate_timing = self.compare_config(
                func, candidate_config, best_config, best_timing
            )
            if cmp_res:
                improved = True
                best_config, best_timing = candidate_config, candidate_timing

        return improved, best_config, best_timing

    def compare_config(self, func, candidate_config, best_config, best_timing):
        """
        Check if candidate_config is better than best_config.

        Return a tuple of (compare_result, candidate_timing).
        compare_result is true iff candidate_config is better.
        A candidate whose benchmark raises is reported as not better, with
        an infinite timing.
        """
        log.debug("Try config %s", candidate_config)
        try:
            candidate_timing = self.call_func(func, candidate_config)
        except Exception as e:
            log.debug("Got exception %s", e)
            return False, float("inf")

        if not self.has_improvement(best_timing, candidate_timing):
            return False, candidate_timing

        log.debug(
            "Tune from %s %f -> %s %f",
            best_config,
            best_timing,
            candidate_config,
            candidate_timing,
        )
        return True, candidate_timing

    def autotune(
        self,
        func: Callable[["triton.Config"], float],
        baseline_config: "triton.Config",
        baseline_timing: Optional[float] = None,
    ) -> "triton.Config":
        """
        Run coordinate descent starting from ``baseline_config`` and return
        the best config found.

        ``func`` maps a config to its timing. ``baseline_timing`` is
        measured with ``func`` when it is not supplied.
        """
        if baseline_timing is None:
            baseline_timing = self.call_func(func, baseline_config)

        log.debug("= Do coordinate descent tuning for %s =", self.name)
        log.debug(
            "Baseline Config %s, baseline timing %f", baseline_config, baseline_timing
        )
        best_config, best_timing = baseline_config, baseline_timing
        tunable_fields = self.tunable_fields

        # Keep sweeping over the fields until a full sweep brings nothing.
        while True:
            improved = False

            for name in tunable_fields:
                cur_val = get_field(best_config, name)
                # some kernels don't have RBLOCK/YBLOCK/ZBLOCK, so cur_val
                # may be None
                if cur_val is None:
                    continue

                # It's possible that there are no neighbour values at all.
                # E.g., if XBLOCK is 1 initially and size_hint for x is also 1.
                # We would not try either larger or smaller XBLOCK in this case.
                for next_val in self.get_neighbour_values(name, cur_val):
                    candidate_config = copy.deepcopy(best_config)
                    set_field(candidate_config, name, next_val)

                    cmp_res, candidate_timing = self.compare_config(
                        func, candidate_config, best_config, best_timing
                    )
                    if cmp_res:
                        improved = True
                        best_config = candidate_config
                        best_timing = candidate_timing

            if not improved and self.inductor_meta.get(
                "coordinate_descent_check_all_directions"
            ):
                old_best_timing = best_timing
                improved, best_config, best_timing = self.check_all_tuning_directions(
                    func, best_config, best_timing
                )

                if improved:
                    msg = red_text(
                        "Coordinate descend tuning found improvement of %.3fx by looking in all directions."
                    )
                    log.debug(
                        msg,
                        old_best_timing / best_timing,
                    )

            if not improved:
                break

        log.debug(
            "Improve from %s %f -> %s %f, %.3fx",
            baseline_config,
            baseline_timing,
            best_config,
            best_timing,
            baseline_timing / best_timing,
        )

        return best_config