# mypy: allow-untyped-defs
"""This file implements the IndexPropagation ops handler, which wraps an
underlying handler to add a limited form of constant propagation, as well as
propagation of sympy expressions downstream of ops.index_expr calls.

For example, say we have the IR:

   tmp0 = ops.index_expr(x, torch.int32)
   tmp1 = ops.constant(2, torch.int32)
   tmp2 = ops.mul(tmp0, tmp1)
   tmp3 = ops.indirect_indexing(tmp2, x_size)
   tmp4 = ops.load("buf0", tmp3)

The underlying handler would just see:

   ops.load("buf0", x * 2)

This is limited by the set of operators handled in the sympy expression
printers. So simple operations like minimum and maximum cannot be translated to
SymPy expressions yet, despite sympy.Min and sympy.Max existing.

"""
import itertools
from dataclasses import dataclass
from typing import Any, Callable, Dict, Literal, Optional, overload, Tuple, Union
from typing_extensions import TypeAlias

import sympy

import torch
from torch._prims_common import dtype_to_type, is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges

from .sizevars import evaluate_expr
from .utils import generate_assert
from .virtualized import V


_ExprType = Union[sympy.Expr, float, int, bool]


def _is_constant(val: _ExprType):
    if isinstance(val, sympy.Basic):
        return val.is_number
    return isinstance(val, (int, float, bool))


def upper_bound(val: _ExprType):
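    # bound_sympy() performs value-range analysis on sympy expressions; plain
    # Python numbers are already concrete and serve as their own bound.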
    return bound_sympy(val).upper if isinstance(val, sympy.Expr) else val


@dataclass
class TypedExpr:
    """A SymPy expression with associated type"""

    expr: _ExprType
    dtype: torch.dtype

    def is_constant(self):
        return _is_constant(self.expr)

    def __post_init__(self):
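        # Normalize constants to the Python scalar type matching dtype (int,
        # float or bool), so later arithmetic sees a plain Python number of the
        # expected kind.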
        if _is_constant(self.expr):
            self.expr = dtype_to_type(self.dtype)(self.expr)


class SymPyOps:
    """An ops handler where all IR values are SymPy expressions

    When a value cannot be represented as a SymPy expression, the method is
    either not defined or returns NotImplemented.

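    For example (an illustrative sketch; any integer dtype behaves the same):

        i = SymPyOps.index_expr(sympy.Symbol("i"), torch.int64)
        two = SymPyOps.constant(2, torch.int64)
        SymPyOps.mul(i, two)  # TypedExpr(expr=2*i, dtype=torch.int64)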
    """

    @staticmethod
    def identity(value: Any) -> Any:
        return value

    @staticmethod
    def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr:
        return TypedExpr(value, dtype)

    @staticmethod
    def index_expr(value: Union[sympy.Expr, int], dtype: torch.dtype) -> TypedExpr:
        return TypedExpr(value, dtype)

    @staticmethod
    def to_dtype(
        value: TypedExpr,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types: bool = False,
    ) -> TypedExpr:
        return TypedExpr(value.expr, dtype)

    @staticmethod
    def abs(x: TypedExpr) -> TypedExpr:
        return TypedExpr(abs(x.expr), x.dtype)  # type: ignore[arg-type]

    @staticmethod
    def square(x: TypedExpr) -> TypedExpr:
        return TypedExpr(x.expr * x.expr, x.dtype)

    @staticmethod
    def add(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr + y.expr, result_type)

    @staticmethod
    def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr - y.expr, result_type)

    @staticmethod
    def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr * y.expr, result_type)

    @staticmethod
    def neg(x: TypedExpr) -> TypedExpr:
        return TypedExpr(-x.expr, x.dtype)

    @staticmethod
    def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        if not is_integer_dtype(result_type):
            return NotImplemented

        return TypedExpr(FloorDiv(x.expr, y.expr), result_type)

    @staticmethod
    def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
        result_type = torch.promote_types(x.dtype, y.dtype)
        if not is_integer_dtype(result_type):
            return NotImplemented

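        # ModularIndexing(a, q, m) represents (a // q) % m, so with q == 1 this
        # lowers x % y directly to an indexing-friendly sympy expression.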
        result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
        return TypedExpr(result_expr, result_type)

    @staticmethod
    def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
        result_type = torch.promote_types(x.dtype, y.dtype)
        if not is_integer_dtype(result_type):
            return NotImplemented

        x_expr = sympy.sympify(x.expr)
        y_expr = sympy.sympify(y.expr)
        # In these cases, remainder in Python == remainder in C++, so this transformation
        # is sound
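        # e.g. 7 % 5 == 2 under both conventions, but -7 % 5 is 3 in Python and
        # -2 in C++ (truncated division), so the rewrite below only fires when
        # the signs of x and y are known to match.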
        if (
            x_expr.is_nonnegative is not None
            and x_expr.is_nonnegative == y_expr.is_positive
        ):
            result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
            return TypedExpr(result_expr, result_type)
        return NotImplemented

    @staticmethod
    def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(sympy.Min(x.expr, y.expr), result_type)

    @staticmethod
    def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(sympy.Max(x.expr, y.expr), result_type)


@dataclass
class IndexPropVar:
    value: Any  # Either an IR value, or TypedExpr if is_symbolic is true
    is_symbolic: bool = False

    @staticmethod
    def new_symbolic(expr: TypedExpr) -> "IndexPropVar":
        return IndexPropVar(expr, is_symbolic=True)

    def __post_init__(self):
        assert not self.is_symbolic or isinstance(
            self.value, TypedExpr
        ), "Symbolic IndexPropVar must contain a TypedExpr"


IndexPropResult: TypeAlias = Union[IndexPropVar, Tuple["IndexPropResult", ...]]


class IndexPropagation:
    """Ops wrapper that tries to propagate constant and index_expr values through the computation.

    This aims to maximize the simplification possible at compile time and to
    convert indirect indexing from arange into normal static indexing.

    """

    def __init__(
        self,
        inner: Any,
        iter_ranges: Dict[sympy.Symbol, sympy.Expr],
        indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr],
    ) -> None:
        self._inner = inner
        self.shape_env = V.graph.sizevars.shape_env

        var_to_range = {
            k: ValueRanges(0, upper_bound(v) - 1) for k, v in iter_ranges.items()
        }
        self.var_to_range = tuple(
            itertools.chain(self.shape_env.var_to_range.items(), var_to_range.items())
        )
        # NOTE: this is intentionally kept as a reference so the caller can
        # update it in-place
        self.indirect_var_ranges = indirect_var_ranges

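        # Record 0 <= x < s for every iteration variable; together with the
        # shape_env's axioms, these are what statically_true() uses to decide
        # whether wrapping and bound checks can be elided.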
        axioms = []
        for x, s in iter_ranges.items():
            axioms.append(0 <= x)
            axioms.append(x < s)
        self.axioms = tuple(axioms) + self.shape_env.get_axioms()

    def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any:
        # Construct a new constant/index_expr from the SymPy expression
        if _is_constant(expr):
            val = dtype_to_type(dtype)(expr)
            return self._inner.constant(val, dtype)
        return self._inner.index_expr(expr, dtype)

    def unwrap(self, a: Union[Any, IndexPropVar]) -> Any:
        if isinstance(a, (list, tuple)):
            return tuple(self.unwrap(v) for v in a)

        if not isinstance(a, IndexPropVar):
            return a

        # Prefer the sympy representation if possible
        if a.is_symbolic:
            return self.materialize_expr(a.value.expr, a.value.dtype)

        return a.value

    def wrap(self, a) -> IndexPropResult:
        if isinstance(a, (list, tuple)):
            return tuple(self.wrap(v) for v in a)
        return IndexPropVar(a)

    @overload
    def fallback(
        self,
        name: Literal["indirect_indexing"],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> IndexPropVar:
        ...

    @overload
    def fallback(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        ...

    def fallback(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        # Fallback to the wrapped handler
        new_args = [self.unwrap(a) for a in args]
        new_kwargs = {k: self.unwrap(v) for k, v in kwargs.items()}
        return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))

    def propagate_sympy(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        # Build a new SymPy expression from this ops call
        def unwrap(a: Union[Any, IndexPropVar]) -> Any:
            if not isinstance(a, IndexPropVar):
                return a
            return a.value

        new_args = [unwrap(a) for a in args]
        new_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
        new_expr = getattr(SymPyOps, name)(*new_args, **new_kwargs)
        is_valid_expr = new_expr is not NotImplemented and (
            # Inductor doesn't expect floating point in sympy expressions, but
            # we do allow floating point constants to be propagated
            new_expr.is_constant()
            or new_expr.expr.is_integer
        )
        if not is_valid_expr:
            return self.fallback(name, args, kwargs)
        return IndexPropVar.new_symbolic(new_expr)

    def __getattr__(self, name: str) -> Callable[..., IndexPropResult]:
        def inner(*args: Any, **kwargs: Any) -> IndexPropResult:
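            # Propagate through SymPyOps only when the op is supported there and
            # every IndexPropVar argument is already symbolic; anything else
            # falls back to the wrapped handler.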
            if not hasattr(SymPyOps, name):
                return self.fallback(name, args, kwargs)

            var_arguments = [
                a
                for a in itertools.chain(args, kwargs.values())
                if isinstance(a, IndexPropVar)
            ]
            if not all(v.is_symbolic for v in var_arguments):
                return self.fallback(name, args, kwargs)

            return self.propagate_sympy(name, args, kwargs)

        return inner

    def statically_true(self, e):
        """
        Given an expression, return whether it is statically known to be true
        under this handler's iter_ranges, using value ranges, guard knowledge
        and runtime_asserts.

        FIXME I think this may not be entirely right, as we may not be able to use all runtime_asserts.
              If this is an issue, just use guards in `self.axioms`.

              The proper way of handling this would be to have a global shape_env that adds
              runtime_asserts as they happen in the code. Then, it should be used in SimplifyIndexing
              to perform wrap_expr and in CSEProxy.check_bounds to elide upper / lower bound checks
              for indirect_indexing as well.
        """
        var_to_range = (
            *self.var_to_range,
            *(
                (k, ValueRanges(0, upper_bound(v) - 1))
                for k, v in self.indirect_var_ranges.items()
            ),
        )
        return evaluate_expr(self.shape_env, e, self.axioms, var_to_range)

    def indirect_indexing(
        self,
        index: Union[Any, IndexPropVar],
        size: Any,
        check: bool = True,
        wrap_neg=True,
    ) -> Any:
        if isinstance(index, IndexPropVar) and index.is_symbolic:
            # If we find something we can convert into direct indexing, we do so.
            # We may still need to wrap the expression and add bound checks.
            # We want to do this "constant folding" because we don't allow
            # fusing kernels into indirect indexing.

            expr = sympy.sympify(index.value.expr)

            # TODO Perhaps move this logic to the simplify indexing pass
            def wrap_expr(expr):
                # Positive, negative, mixed
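                # e.g. with size == 10: an index proven nonnegative is left
                # unchanged, one proven negative becomes expr + 10 (so -1 maps
                # to 9), and otherwise Where(expr < 0, expr + size, expr) picks
                # the wrapped value at runtime.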
                if self.statically_true(0 <= expr):
                    return expr
                elif self.statically_true(expr < 0):
                    return expr + size
                else:
                    return Where(expr < 0, expr + size, expr)

            # Sometimes it's easier to prove 0 <= expr than the weaker -size <= expr
            can_prove_lower = self.statically_true(0 <= expr) or self.statically_true(
                -size <= expr
            )
            can_prove_upper = self.statically_true(expr < size)
            if wrap_neg:
                expr = wrap_expr(expr)
            if generate_assert(check):
                self.fallback(
                    "check_bounds",
                    (expr, size),
                    dict(lower=not can_prove_lower, upper=not can_prove_upper),
                )
            return expr

        indirect_var = self.fallback(
            "indirect_indexing", (index, size, check, wrap_neg), {}
        ).value
        return indirect_var