# mypy: allow-untyped-defs
"""
This is a simple interpreter for Sympy expressions that dispatches to
classes following the torch._inductor.virtualized calling convention.
For directness, the interpreter takes the handler directly rather than
consulting the TLS. It does not use most of the methods on the full
handler; only those with corresponding Sympy expressions. To see an example
of a full handler, see torch.utils._sympy.value_ranges.ValueRangeAnalysis.
"""

import functools
import logging
from typing import Any, Dict, Union

import sympy
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom

import torch

from .functions import (
    CeilToInt,
    CleanDiv,
    FloatPow,
    FloatTrueDiv,
    FloorDiv,
    FloorToInt,
    Identity,
    IntTrueDiv,
    IsNonOverlappingAndDenseIndicator,
    Max,
    Min,
    Mod,
    ModularIndexing,
    PowByNatural,
    PythonMod,
    RoundDecimal,
    RoundToInt,
    ToFloat,
    TruncToFloat,
    TruncToInt,
    Where,
)


log = logging.getLogger(__name__)


# TODO: Dedupe this with SYMPY_INTERP


@functools.lru_cache(None)
def handlers():
    """Return the mapping from Sympy expression classes to handler method names.

    Each value is the name of the method looked up on the ``analysis`` object
    by ``_run_sympy_handler``. The result is cached by ``functools.lru_cache``,
    so every call returns the same dict object; treat it as read-only.
    """
    # TODO add CeilDiv (it doesn't appear in the index_expr)

    # TODO default to some decompositions if the interpreter doesn't have them
    # like decomposing ModularIndexing or implementing Le(a,b) as Ge(b, a)

    HANDLERS = {
        sympy.Or: "or_",
        sympy.And: "and_",
        sympy.Eq: "eq",
        sympy.Ne: "ne",
        sympy.Lt: "lt",
        sympy.Gt: "gt",
        sympy.Le: "le",
        sympy.Ge: "ge",
        sympy.Not: "not_",
        IntTrueDiv: "int_truediv",
        FloatTrueDiv: "truediv",
        FloorDiv: "floordiv",
        CleanDiv: "floordiv",  # TODO: hmm?
        TruncToFloat: "trunc",
        Where: "where",
        sympy.Add: "add",
        sympy.Mul: "mul",
        FloatPow: "pow",
        PowByNatural: "pow_by_natural",
        # sympy simplifies x * x into Pow(x, 2), so we need to handle this.
        # Do NOT use builtin Pow for floats
        # TODO: There is a hazard here, if we have float * float it will
        # also get turned into Pow(float, 2) but we don't want this because
        # pow_by_natural is assumed to only be integers.  Probably the fix is
        # to add a FloatMul to impede this optimization
        sympy.Pow: "pow_by_natural",
        Mod: "mod",
        PythonMod: "mod",  # TODO: this is wrong
        # TODO: Inductor can generate these, but it's ill-specified which
        # semantics were intended here.  Needs to be cleaned up along with
        # FloorDiv in a bigger cleanup
        sympy.Mod: "mod",
        sympy.Abs: "abs",
        sympy.log: "log",
        sympy.exp: "exp",
        sympy.Min: "minimum",
        sympy.Max: "maximum",
        Min: "minimum",
        Max: "maximum",
        ModularIndexing: "modular_indexing",
        sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair",
        sympy.Piecewise: "piecewise",
        Identity: "identity",
        IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator",
        RoundDecimal: "round_decimal",
    }
    # Trigonometric / hyperbolic functions share their Sympy name with the
    # handler method name.
    for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]:
        HANDLERS[getattr(sympy, name)] = name

    return HANDLERS


# Handler names whose Sympy expressions are n-ary; _run_sympy_handler folds
# their arguments left-to-right through the binary handler.
ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"}


# These handlers are special because they take an extra dtype argument
# specifying what they should convert to, and we need to appropriately set
# this up when we convert from Sympy.  A reasonable default when you
# are translating is to conservatively do int64, and then narrow these
# arguments later when you discover you can narrow the index range.  But
# if you already know that 32-bit indexing is OK, you can directly do the
# sympy translation with index_dtype=torch.int32
_INDEX_DTYPE_HANDLERS = {
    TruncToInt: "trunc_to_int",
    sympy.floor: "floor_to_int",
    sympy.ceiling: "ceil_to_int",
    FloorToInt: "floor_to_int",
    CeilToInt: "ceil_to_int",
    RoundToInt: "round_to_int",
}


def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64):
    """Apply the ``analysis`` handler that corresponds to ``expr``'s class.

    Args:
        analysis: object providing the handler methods named in ``handlers()``
            (or via an expression class's ``_torch_handler_name``), plus
            ``sqrt``, ``to_dtype`` and the ``*_to_int`` methods used below.
        args: already-interpreted values of ``expr.args``, in the same order.
        expr: the Sympy expression being interpreted; only its class and, for
            ``sympy.Pow``, its exponent are inspected here.
        index_dtype: dtype passed as the extra trailing argument to the
            ``trunc_to_int`` / ``floor_to_int`` / ``ceil_to_int`` /
            ``round_to_int`` handlers.

    Returns:
        Whatever the selected handler returns.

    Raises:
        KeyError: if ``expr.func`` has no ``_torch_handler_name`` and is not
            registered in ``handlers()``.
        Any exception raised by the handler itself is logged at WARNING level
        and re-raised unchanged.
    """
    # Special cases
    # x ** (1/2) is routed to sqrt rather than the generic pow handler.
    if isinstance(expr, sympy.Pow) and isinstance(
        expr.args[1], sympy.core.numbers.Half
    ):
        return analysis.sqrt(args[0])
    if isinstance(expr, ToFloat):
        return analysis.to_dtype(args[0], torch.float64)

    if (handler_name := _INDEX_DTYPE_HANDLERS.get(expr.func)) is not None:
        return getattr(analysis, handler_name)(*args, index_dtype)

    # Expression classes can opt in to a handler name without being listed
    # in handlers().
    if hasattr(expr.func, "_torch_handler_name"):
        handler_name = expr.func._torch_handler_name
    else:
        handler_name = handlers()[expr.func]
    handler = getattr(analysis, handler_name)
    try:
        if handler_name in ASSOCIATIVE_OPS:
            # Fold the n-ary Sympy node into a chain of binary handler calls.
            assert len(args) > 1
            acc = handler(args[0], args[1])
            for arg in args[2:]:
                acc = handler(acc, arg)
            log.debug("%s(%s) -> %s", handler_name, args, acc)
            return acc
        else:
            r = handler(*args)
            log.debug("%s(%s) -> %s", handler_name, args, r)
            return r
    except Exception:
        log.warning("failed while executing %s(%s)", handler_name, args)
        raise


def sympy_interp(
    analysis,
    env: Dict[sympy.Symbol, Any],
    expr: Union[sympy.Expr, SympyBoolean],
    *,
    index_dtype=torch.int64,
):
    """Interpret ``expr`` bottom-up using the handler methods of ``analysis``.

    Leaves are handled directly: boolean atoms, integers and other Sympy
    numbers become ``analysis.constant(expr, dtype)`` with ``torch.bool``,
    ``torch.int64`` and ``torch.double`` respectively, and symbols are looked
    up in ``env``. Every other node has its children interpreted first and
    is then dispatched through ``_run_sympy_handler``.

    Args:
        analysis: handler object following the calling convention described
            in the module docstring.
        env: value to substitute for each ``sympy.Symbol`` reachable in
            ``expr``.
        expr: the Sympy expression or boolean to interpret.
        index_dtype: dtype used by the ``*_to_int`` conversion handlers at
            every depth of ``expr``, not just at the root.

    Returns:
        The value produced by ``analysis`` for ``expr``.

    Raises:
        KeyError: if a symbol in ``expr`` is missing from ``env``, or if a
            node's class has no registered handler.
    """
    # Handle base cases
    dtype = None
    if isinstance(expr, BooleanAtom):
        dtype = torch.bool
    elif isinstance(expr, sympy.Integer):
        dtype = torch.int64
    elif isinstance(expr, sympy.Number):
        dtype = torch.double

    if dtype is not None:
        return analysis.constant(expr, dtype)
    elif isinstance(expr, sympy.Symbol):
        return env[expr]

    # Recursive case.  index_dtype must be forwarded to the children: without
    # it, a nested FloorToInt/CeilToInt/... would silently fall back to int64
    # even when the caller requested int32.
    return _run_sympy_handler(
        analysis,
        [
            sympy_interp(analysis, env, arg, index_dtype=index_dtype)
            for arg in expr.args
        ],  # type: ignore[arg-type]
        expr,
        index_dtype=index_dtype,
    )  # type: ignore[arg-type]