xref: /aosp_15_r20/external/pytorch/torch/_inductor/bounds.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import logging
3import operator
4from functools import partial
5from typing import Any, Callable, Dict
6
7from sympy import Expr
8
9import torch
10from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
11
12from .loop_body import InterpreterShim, LoopBody, LoopBodyBlock
13from .utils import cache_on_self, dominated_nodes
14from .virtualized import V
15
16
17log = logging.getLogger(__name__)
18
19
class BoundVars:
    """
    Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run()
    It exposes the ranges of the nodes in the `bounds` variable

    Note. A current limitation of this analysis is that it just works on a per-loop basis.
    We should be able to propagate the bounds between across the whole graph. This may benefit
    the case a bounded variable is returned by a kernel and fed into another.
    """

    def __init__(self, loop_body: LoopBody) -> None:
        def upper_bound(v):
            # Symbolic sizes are resolved to an upper bound via bound_sympy;
            # concrete (int) sizes pass through unchanged.
            return bound_sympy(v).upper if isinstance(v, Expr) else v

        self.loop_body = loop_body
        # Each iteration variable k with extent v ranges over the closed
        # interval [0, v - 1].
        self.replacement_vals = {
            k: ValueRanges[Expr](0, upper_bound(v) - 1)
            for k, v in loop_body.var_ranges.items()
        }
        # avoid computing these values, pessimistically assume that they are unbounded
        # (loads, reductions, getitem results, and masked subblocks — plus every
        # node dominated by one of them — are seeded as unknown in get_bounds()).
        self.unbounded_vars = dominated_nodes(
            node
            for node in self.loop_body.get_nodes()
            if node.target in ["load", "reduction", operator.getitem]
            or "masked_subblock" in node.target
        )
        # To access this variable call `get_bounds()`
        self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {}

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"loop_body={self.loop_body},\n "
            f"replacement_vals={self.replacement_vals}, \n"
            f"unbounded_vars={self.unbounded_vars}, \n"
            f"_bounds={self._bounds})"
        )

    @cache_on_self
    def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]:
        """Run value-range analysis over the loop body's root fx graph.

        Returns a mapping from each fx node to its computed ValueRanges.
        Cached via @cache_on_self, so the interpretation happens at most once
        per BoundVars instance.
        """
        submodules = self.swap_submodules(self.loop_body.submodules)

        # Initialize the environment with the unbounded variables
        for node in self.unbounded_vars:
            # we need to evaluate masked_subblock to recurse, and we need to set indirect values
            if not isinstance(node.target, str) or (
                "masked_subblock" not in node.target
                and "set_indirect" not in node.target
            ):
                self._bounds[node] = ValueRanges[Expr].unknown()

        # ValueRangeAnalysis is installed as the ops handler so every op the
        # interpreter executes produces a ValueRanges instead of a value.
        with V.set_ops_handler(ValueRangeAnalysis()):
            interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules)
            log.debug("get_bounds:\n%s", self.loop_body.root_block.graph)
            # initial_env is self._bounds itself, so the interpreter fills the
            # cache in place as it runs.
            interpreter.run(V.get_ops_handler(), initial_env=self._bounds)
        return self._bounds

    def swap_submodules(
        self, submodules: Dict[str, Callable[..., Any]]
    ) -> Dict[str, Callable[..., ValueRanges[Expr]]]:
        """Replace each loop-body submodule with a bounds-aware equivalent.

        get_index / masked_subblock / set_indirect calls are redirected to the
        methods on this class; "scan" submodules are passed through unchanged.
        """
        result: Dict[str, Callable[..., ValueRanges[Expr]]] = {}
        for key in submodules.keys():
            if key == "get_index":
                result[key] = self.get_index
            elif "masked_subblock" in key:
                subblock = self.loop_body.subblocks[key]
                # The result within the lambda will reference to the final
                # set of modules at the end of the for-loop as it stores a reference to it

                # bind subblock in a function because python lambdas close over by reference
                # moving the lambda out of make_fn would close over the reference to subblock,
                # so all lambdas would have the same subblock reference that is the final
                # subblock in the loop
                def make_fn(subblock):
                    return lambda mask, value: self.masked_subblock(
                        subblock, self._bounds, mask, value, result
                    )

                result[key] = make_fn(subblock)
            elif "set_indirect" in key:
                # Key is "set_indirect<idx>"; recover idx to find the variable.
                idx = int(key[len("set_indirect") :])
                var = self.loop_body.indirect_vars[idx]
                indirect = partial(self.set_indirect, var)
                result[key] = indirect
            else:
                assert "scan" in key
                result[key] = submodules[key]

        return result

    def masked_subblock(
        self,
        subblock: LoopBodyBlock,
        env: Dict[torch.fx.Node, ValueRanges[Expr]],
        mask: Any,
        value: Any,
        submodules: Dict[str, Callable[..., Any]],
    ) -> ValueRanges[Expr]:
        """Recursively analyze a masked subblock and return its output's bounds.

        `mask` and `value` are accepted to match the submodule call signature
        but are not used directly (see comment below about `value`).
        """
        interp = InterpreterShim(subblock.graph, submodules)
        interp.run(V.get_ops_handler(), initial_env=env)
        output = [node for node in subblock.graph.nodes if node.target == "output"]
        # A subblock graph is expected to have exactly one output node.
        assert len(output) == 1
        # dont bother unioning with value since the load from buffer will be
        # pessimistically assumed to be inf anyway
        return interp.env[output[0]]

    def set_indirect(self, old: Expr, new: ValueRanges[Expr]) -> ValueRanges[Expr]:
        """Record the bounds of an indirect-indexing variable and return them."""
        assert isinstance(new, ValueRanges)
        self.replacement_vals[old] = new
        return new

    def get_index(self, name: Expr) -> ValueRanges[Expr]:
        """Return (and memoize) the value range of the named indexing expression.

        NOTE(review): `name` is used as a key into `loop_body.indexing_exprs`
        and `replacement_vals`, which suggests it may actually be a str name
        rather than an Expr — confirm against callers before tightening types.
        """
        expr = self.loop_body.indexing_exprs[name]
        # Indirect indexing may have registered bounds for this expression
        # already (via set_indirect); fall back to sympy bound propagation.
        bound = self.replacement_vals.get(expr)
        if bound is None:
            bound = bound_sympy(expr, self.replacement_vals)
        # The following assertion is true at the time of this writing
        # We don't assert is as to not execute bound_sympy when bound is not None
        # assert bound is None or bound == bound_sympy(expr, self.replacement_vals)
        self.replacement_vals[name] = bound
        return bound
141