xref: /aosp_15_r20/external/pytorch/torch/_inductor/runtime/coordinate_descent_tuner.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import copy
3import itertools
4import logging
5from typing import Callable, Optional
6
7from .hints import TRITON_MAX_BLOCK
8from .runtime_utils import red_text, triton_config_to_hashable
9
10
11try:
12    import triton
13except ImportError:
14    triton = None
15
16log = logging.getLogger(__name__)
17
18
def get_field(config, name):
    """Read a tunable field from a triton config.

    ``num_warps`` and ``num_stages`` are real attributes on the config;
    every other field (XBLOCK, RBLOCK, BLOCK_M, ...) lives in
    ``config.kwargs``. Returns None when a kwargs field is absent.
    """
    if name in ("num_warps", "num_stages"):
        return getattr(config, name)
    return config.kwargs.get(name, None)
26
27
def set_field(config, name, value):
    """Write a tunable field into a triton config.

    Mirrors ``get_field``: ``num_warps``/``num_stages`` are stored as
    attributes, anything else goes into ``config.kwargs``.
    """
    if name in ("num_warps", "num_stages"):
        setattr(config, name, value)
    else:
        config.kwargs[name] = value
35
36
class CoordescTuner:
    """
    Coordinate descent tuner: starting from a baseline triton config,
    repeatedly perturb one field (coordinate) at a time and keep any
    change that improves the benchmarked timing.

    TODO will it be necessary to tune multiple fields simultaneously.

    TODO: what if both increasing and decreasing a field can improve perf.
          i.e., there are multiple local optima..
    """

    def __init__(
        self, is_mm=False, name="unknown", size_hints=None, inductor_meta=None
    ):
        self.is_mm = is_mm  # we will tune num_stages for mm
        # benchmark cache: hashable config -> timing, so a config revisited
        # from a different search path is never re-benchmarked
        self.cached_benchmark_results = {}
        self.name = name
        self.size_hints = size_hints
        self.inductor_meta = inductor_meta or {}

    def prefix_to_size_hint(self, prefix: str) -> Optional[int]:
        """Return the size hint for a block prefix (X/Y/Z/R), or None.

        "R" maps to index -1: the reduction hint is the last entry of
        size_hints regardless of how many pointwise axes precede it.
        """
        idx = {"X": 0, "Y": 1, "Z": 2, "R": -1}[prefix]
        hints = self.size_hints
        if hints and len(hints) > idx:
            return hints[idx]
        return None

    def get_config_max(self, prefix: str) -> int:
        """Upper bound for a block size: TRITON_MAX_BLOCK capped by the
        size hint (no point trying a block larger than the problem)."""
        upper = TRITON_MAX_BLOCK[prefix]
        hint = self.prefix_to_size_hint(prefix)
        if hint is None:
            return upper
        return min(upper, hint)

    def get_warpsmax(self):
        # CUDA caps a block at 1024 threads; at 32 threads per warp that
        # allows at most 32 warps.
        return 1024 // 32

    def cache_benchmark_result(self, config, timing):
        """Remember the timing for a config."""
        key = triton_config_to_hashable(config)
        self.cached_benchmark_results[key] = timing

    def lookup_in_cache(self, config):
        """Return the cached timing for a config, or None if unseen."""
        key = triton_config_to_hashable(config)
        return self.cached_benchmark_results.get(key)

    def call_func(self, func, config):
        """Benchmark ``config`` via ``func``, going through the cache."""
        cached = self.lookup_in_cache(config)
        if cached is not None:
            log.debug("  CACHED")
            return cached
        timing = func(config)
        self.cache_benchmark_result(config, timing)
        return timing

    @property
    def tunable_fields(self):
        """Fields the tuner may perturb; num_stages only for mm kernels."""
        fields = [
            "XBLOCK",
            "YBLOCK",
            "ZBLOCK",
            # NOTE: we should not tune RBLOCK for persistent reduction.
            # We rely on the fact that persistent reduction's triton.Config
            # does not have the RBLOCK field to guarantee that.
            "RBLOCK",
            # the following 3 are for mm
            "BLOCK_M",
            "BLOCK_N",
            "BLOCK_K",
            "num_warps",
        ]
        if self.is_mm:
            fields.append("num_stages")
        return fields

    def value_too_large(self, name: str, val: int) -> bool:
        """True if ``val`` exceeds the legal maximum for field ``name``."""
        if name in ("XBLOCK", "YBLOCK", "ZBLOCK", "RBLOCK"):
            # name[0] is the block prefix, e.g. "X" for "XBLOCK"
            return val > self.get_config_max(name[0])
        if name == "num_warps":
            return val > self.get_warpsmax()
        return False

    def get_neighbour_values(self, name, orig_val, radius=1, include_self=False):
        """
        Return the values within ``radius`` tuning steps of ``orig_val``.

        num_stages steps additively (+/-1); every other field steps
        multiplicatively (x2 or //2). Walking stops early at a too-large
        or non-positive value. The original value is only included when
        ``include_self`` is True.
        """
        assert radius >= 1

        additive = name == "num_stages"

        neighbours = []

        # walk upward from orig_val
        val = orig_val
        for _ in range(radius):
            val = val + 1 if additive else val * 2
            if self.value_too_large(name, val):
                break
            neighbours.append(val)

        # walk downward from orig_val
        val = orig_val
        for _ in range(radius):
            val = val - 1 if additive else val // 2
            if val <= 0:
                break
            neighbours.append(val)

        if include_self:
            neighbours.append(orig_val)
        return neighbours

    @staticmethod
    def has_improvement(baseline, test):
        """True iff ``test`` beats ``baseline`` by more than 0.1%."""
        threshold = 0.001  # 0.1%
        return test is not None and test < baseline * (1 - threshold)

    def check_all_tuning_directions(
        self,
        func: Callable[["triton.Config"], float],
        best_config,
        best_timing,
    ):
        """
        Benchmark every combination of neighbouring values across all
        tunable fields at once. Only invoked after plain per-coordinate
        descent stalls; the small number of tunable fields keeps the
        cross product tractable.

        Returns (improved, best_config, best_timing).
        """
        radius = self.inductor_meta.get("coordinate_descent_search_radius", 1)
        fields = []
        values_per_field = []
        for name in self.tunable_fields:
            cur = get_field(best_config, name)
            if cur is None:
                # this kernel's config does not carry the field
                continue
            fields.append(name)
            values_per_field.append(
                self.get_neighbour_values(
                    name, cur, radius=radius, include_self=True
                )
            )

        improved = False
        for combo in itertools.product(*values_per_field):
            assert len(combo) == len(fields)
            candidate = copy.deepcopy(best_config)
            for name, val in zip(fields, combo):
                set_field(candidate, name, val)
            better, timing = self.compare_config(
                func, candidate, best_config, best_timing
            )
            if better:
                improved = True
                best_config = candidate
                best_timing = timing

        return improved, best_config, best_timing

    def compare_config(self, func, candidate_config, best_config, best_timing):
        """
        Benchmark ``candidate_config`` and compare it to the incumbent.

        Returns a tuple (compare_result, candidate_timing);
        compare_result is True iff candidate_config is better. A config
        that raises while benchmarking compares as infinitely slow.
        """
        log.debug("Try config %s", candidate_config)
        try:
            candidate_timing = self.call_func(func, candidate_config)
        except Exception as e:
            log.debug("Got exception %s", e)
            return False, float("inf")

        if not self.has_improvement(best_timing, candidate_timing):
            return False, candidate_timing

        log.debug(
            "Tune from %s %f -> %s %f",
            best_config,
            best_timing,
            candidate_config,
            candidate_timing,
        )
        return True, candidate_timing

    def autotune(
        self,
        func: Callable[["triton.Config"], float],
        baseline_config: "triton.Config",
        baseline_timing: Optional[float] = None,
    ) -> "triton.Config":
        """Run coordinate descent from ``baseline_config``.

        ``func`` maps a triton config to its measured timing; the timing
        of the baseline is measured unless supplied. Returns the fastest
        config found.
        """
        if baseline_timing is None:
            baseline_timing = self.call_func(func, baseline_config)

        log.debug("= Do coordinate descent tuning for %s =", self.name)
        log.debug(
            "Baseline Config %s, baseline timing %f", baseline_config, baseline_timing
        )
        best_config = baseline_config
        best_timing = baseline_timing
        fields = self.tunable_fields

        improved = True
        while improved:
            improved = False

            for name in fields:
                cur_val = get_field(best_config, name)
                # some kernel don't have RBLOCK/YBLOCK/ZBLOCK. So cur_val may be None
                if cur_val is None:
                    continue

                # get_neighbour_values may return nothing: e.g., if XBLOCK
                # is 1 and the x size hint is also 1, there is no larger or
                # smaller XBLOCK to try.
                for next_val in self.get_neighbour_values(name, cur_val):
                    candidate = copy.deepcopy(best_config)
                    set_field(candidate, name, next_val)

                    better, timing = self.compare_config(
                        func, candidate, best_config, best_timing
                    )
                    if better:
                        improved = True
                        best_config, best_timing = candidate, timing

            # once descent stalls, optionally try moving every field at once
            if not improved and self.inductor_meta.get(
                "coordinate_descent_check_all_directions"
            ):
                prev_best_timing = best_timing
                improved, best_config, best_timing = self.check_all_tuning_directions(
                    func, best_config, best_timing
                )

                if improved:
                    msg = red_text(
                        "Coordinate descend tuning found improvement of %.3fx by looking in all directions."
                    )
                    log.debug(
                        msg,
                        prev_best_timing / best_timing,
                    )

        log.debug(
            "Improve from %s %f -> %s %f, %.3fx",
            baseline_config,
            baseline_timing,
            best_config,
            best_timing,
            baseline_timing / best_timing,
        )

        return best_config
305