xref: /aosp_15_r20/external/pytorch/torch/utils/_sympy/interp.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
"""
This is a simple interpreter for Sympy expressions that dispatches to
classes following the torch._inductor.virtualized calling convention.
For directness, the interpreter takes the handler as an explicit argument
rather than looking it up in thread-local storage (TLS).  It uses only the
handler methods that correspond to Sympy expressions, not the full handler
interface.  For an example of a full handler, see
torch.utils._sympy.value_ranges.ValueRangeAnalysis.
"""
10
11import functools
12import logging
13from typing import Any, Dict, Union
14
15import sympy
16from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom
17
18import torch
19
20from .functions import (
21    CeilToInt,
22    CleanDiv,
23    FloatPow,
24    FloatTrueDiv,
25    FloorDiv,
26    FloorToInt,
27    Identity,
28    IntTrueDiv,
29    IsNonOverlappingAndDenseIndicator,
30    Max,
31    Min,
32    Mod,
33    ModularIndexing,
34    PowByNatural,
35    PythonMod,
36    RoundDecimal,
37    RoundToInt,
38    ToFloat,
39    TruncToFloat,
40    TruncToInt,
41    Where,
42)
43
44
# Module-level logger used by the interpreter to trace handler calls and to
# report handlers that raise.
log = logging.getLogger(name=__name__)


# TODO: Dedupe this with SYMPY_INTERP
49
50
@functools.lru_cache(maxsize=None)
def handlers():
    """Return the table mapping Sympy node classes to handler method names.

    Each value is the name of the method looked up on the analysis object
    when a node of the corresponding class is interpreted.  The table is
    built on first use and cached, so every call returns the same dict
    object; treat it as read-only.
    """
    # TODO add CeilDiv (it doesn't appear in the index_expr)

    # TODO default to some decompositions if the interpreter doesn't have them
    # like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a)

    table = {
        # Boolean connectives and relations.
        sympy.Or: "or_",
        sympy.And: "and_",
        sympy.Eq: "eq",
        sympy.Ne: "ne",
        sympy.Lt: "lt",
        sympy.Gt: "gt",
        sympy.Le: "le",
        sympy.Ge: "ge",
        sympy.Not: "not_",
        IntTrueDiv: "int_truediv",
        FloatTrueDiv: "truediv",
        FloorDiv: "floordiv",
        CleanDiv: "floordiv",  # TODO: hmm?
        TruncToFloat: "trunc",
        Where: "where",
        sympy.Add: "add",
        sympy.Mul: "mul",
        FloatPow: "pow",
        PowByNatural: "pow_by_natural",
        # sympy simplifies x * x into Pow(x, 2), so plain Pow must be handled.
        # Do NOT use builtin Pow for floats.
        # TODO: hazard -- float * float is also rewritten to Pow(float, 2),
        # yet pow_by_natural assumes integer operands.  A FloatMul that
        # blocks this simplification is the likely fix.
        sympy.Pow: "pow_by_natural",
        Mod: "mod",
        PythonMod: "mod",  # TODO: this is wrong
        # TODO: Inductor can generate sympy.Mod, but the intended semantics
        # are ill-specified.  Clean this up together with FloorDiv.
        sympy.Mod: "mod",
        sympy.Abs: "abs",
        sympy.log: "log",
        sympy.exp: "exp",
        sympy.Min: "minimum",
        sympy.Max: "maximum",
        Min: "minimum",
        Max: "maximum",
        ModularIndexing: "modular_indexing",
        sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
        sympy.Piecewise: "piecewise",
        Identity: "identity",
        IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
        RoundDecimal: "round_decimal",
    }

    # Trigonometric / hyperbolic functions: the handler method has the same
    # name as the sympy function.
    trig_names = ("cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan")
    table.update({getattr(sympy, fn_name): fn_name for fn_name in trig_names})

    return table
109
110
# Handler names whose Sympy nodes may carry more than two operands.
# _run_sympy_handler applies the binary handler to these left-to-right.
ASSOCIATIVE_OPS = {
    "add",
    "mul",
    "minimum",
    "maximum",
    "and_",
    "or_",
}
112
113
def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64):
    """Apply the ``analysis`` method that corresponds to ``expr.func``.

    Args:
        analysis: object providing the handler methods (``sqrt``,
            ``to_dtype``, ``add``, ``floor_to_int``, ...).
        args: the already-interpreted operands of ``expr``, in the same
            order as ``expr.args``.
        expr: the Sympy node being evaluated; it is only inspected to
            choose which handler to call.
        index_dtype: trailing dtype argument passed to the integer
            conversion handlers (``trunc_to_int``, ``floor_to_int``,
            ``ceil_to_int``, ``round_to_int``).

    Returns:
        Whatever the selected handler returns.

    Raises:
        KeyError: if ``expr.func`` has no ``_torch_handler_name`` attribute
            and no entry in ``handlers()``.
        Exception: any error raised by a generic handler is logged as a
            warning and then re-raised unchanged.
    """
    # Pow(x, 1/2) is evaluated as a square root instead of a power.
    is_sqrt = isinstance(expr, sympy.Pow) and isinstance(
        expr.args[1], sympy.core.numbers.Half
    )
    if is_sqrt:
        return analysis.sqrt(args[0])
    # ToFloat is a dtype conversion that always targets float64.
    if isinstance(expr, ToFloat):
        return analysis.to_dtype(args[0], torch.float64)

    # Conversions to an integer take the target dtype as an extra trailing
    # argument, which has to be supplied when translating from Sympy.
    # int64 is the conservative default: narrow later once the index range is
    # known, or pass index_dtype=torch.int32 up front if 32-bit indexing is
    # already known to be safe.
    conversions = {
        TruncToInt: "trunc_to_int",
        sympy.floor: "floor_to_int",
        sympy.ceiling: "ceil_to_int",
        FloorToInt: "floor_to_int",
        CeilToInt: "ceil_to_int",
        RoundToInt: "round_to_int",
    }
    conversion_name = conversions.get(expr.func)
    if conversion_name is not None:
        return getattr(analysis, conversion_name)(*args, index_dtype)

    # A Sympy function class may name its own handler; otherwise use the
    # shared lookup table.
    handler_name = (
        expr.func._torch_handler_name
        if hasattr(expr.func, "_torch_handler_name")
        else handlers()[expr.func]
    )
    handler = getattr(analysis, handler_name)

    try:
        if handler_name not in ASSOCIATIVE_OPS:
            result = handler(*args)
        else:
            # These nodes may have more than two operands; fold them
            # left-to-right through the binary handler.
            assert len(args) > 1
            result = handler(args[0], args[1])
            for operand in args[2:]:
                result = handler(result, operand)
        log.debug("%s(%s) -> %s", handler_name, args, result)
        return result
    except Exception:
        log.warning("failed while executing %s(%s)", handler_name, args)
        raise
161
162
def sympy_interp(
    analysis,
    env: Dict[sympy.Symbol, Any],
    expr: Union[sympy.Expr, SympyBoolean],
    *,
    index_dtype=torch.int64,
):
    """Evaluate a Sympy expression by dispatching each node to ``analysis``.

    The expression tree is walked bottom-up.  Literal constants and symbols
    are resolved directly.  Every other node is passed to
    ``_run_sympy_handler`` together with the already-interpreted values of
    its children.

    Args:
        analysis: handler object whose methods (``constant``, ``add``,
            ``floordiv``, ...) implement the operations.
        env: maps each free ``sympy.Symbol`` occurring in ``expr`` to its
            value in ``analysis``'s domain.
        expr: the Sympy expression or boolean to interpret.
        index_dtype: dtype passed to the integer conversion handlers
            (``trunc_to_int``, ``floor_to_int``, ``ceil_to_int``,
            ``round_to_int``).  It applies to every node in the tree,
            including nested sub-expressions.

    Returns:
        Whatever ``analysis`` produces for ``expr``.

    Raises:
        KeyError: if ``expr`` contains a symbol that is missing from ``env``,
            or a node type with no registered handler.
    """
    # Base cases: constants are materialized with a dtype chosen from their
    # Sympy kind; symbols are looked up in ``env``.
    dtype = None
    if isinstance(expr, BooleanAtom):
        dtype = torch.bool
    elif isinstance(expr, sympy.Integer):
        dtype = torch.int64
    elif isinstance(expr, sympy.Number):
        # Any other sympy.Number (e.g. a Rational or Float) becomes a double.
        dtype = torch.double

    if dtype is not None:
        return analysis.constant(expr, dtype)
    elif isinstance(expr, sympy.Symbol):
        return env[expr]

    # Recursive case.  index_dtype must be forwarded to the children: the
    # dtype-taking conversion nodes usually sit below the root (for example
    # FloorToInt(x) + 1) and would otherwise fall back to int64.
    return _run_sympy_handler(
        analysis,
        [
            sympy_interp(analysis, env, arg, index_dtype=index_dtype)  # type: ignore[arg-type]
            for arg in expr.args
        ],
        expr,
        index_dtype=index_dtype,
    )  # type: ignore[arg-type]
191