# mypy: allow-untyped-defs
"""This file implements the IndexPropagation ops handler, which wraps an
underlying handler to add a limited form of constant propagation, as well as
propagation of sympy expressions downstream of ops.index_expr calls.

For example, say we have the IR:

    tmp0 = ops.index_expr(x, torch.int32)
    tmp1 = ops.constant(2, torch.int32)
    tmp2 = ops.mul(tmp0, tmp1)
    tmp3 = ops.indirect_indexing(tmp2, x_size)
    tmp4 = ops.load("buf0", tmp3)

The underlying handler would just see:

    ops.load("buf0", x * 2)

This is limited by the set of operators handled in the sympy expression
printers; an operation that cannot be represented as a printable SymPy
expression simply falls back to the wrapped handler.

"""
import itertools
from dataclasses import dataclass
from typing import Any, Callable, Dict, Literal, Optional, overload, Tuple, Union
from typing_extensions import TypeAlias

import sympy

import torch
from torch._prims_common import dtype_to_type, is_integer_dtype
from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Where
from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges

from .sizevars import evaluate_expr
from .utils import generate_assert
from .virtualized import V


_ExprType = Union[sympy.Expr, float, int, bool]


def _is_constant(val: _ExprType):
    if isinstance(val, sympy.Basic):
        return val.is_number
    return isinstance(val, (int, float, bool))


def upper_bound(val: _ExprType):
    return bound_sympy(val).upper if isinstance(val, sympy.Expr) else val


@dataclass
class TypedExpr:
    """A SymPy expression with an associated torch dtype"""

    expr: _ExprType
    dtype: torch.dtype

    def is_constant(self):
        return _is_constant(self.expr)

    def __post_init__(self):
        if _is_constant(self.expr):
            self.expr = dtype_to_type(self.dtype)(self.expr)
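

# Illustrative sketch (not part of the module's logic): TypedExpr.__post_init__
# normalizes constant expressions to the Python type matching ``dtype`` via
# dtype_to_type, so constants are stored as plain int/float/bool values:
#
#     TypedExpr(sympy.Integer(3), torch.float32).expr  # -> 3.0 (a Python float)
#     TypedExpr(2, torch.int64).expr                   # -> 2
#     TypedExpr(sympy.Symbol("x"), torch.int64).expr   # left as a sympy expression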


class SymPyOps:
    """An ops handler where all IR values are SymPy expressions

    When a value cannot be represented as a SymPy expression, the method is
    either not defined, or returns NotImplemented

    """

    @staticmethod
    def identity(value: Any) -> Any:
        return value

    @staticmethod
    def constant(value: Union[int, float, bool], dtype: torch.dtype) -> TypedExpr:
        return TypedExpr(value, dtype)

    @staticmethod
    def index_expr(value: Union[sympy.Expr, int], dtype: torch.dtype) -> TypedExpr:
        return TypedExpr(value, dtype)

    @staticmethod
    def to_dtype(
        value: TypedExpr,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types: bool = False,
    ) -> TypedExpr:
        return TypedExpr(value.expr, dtype)

    @staticmethod
    def abs(x: TypedExpr) -> TypedExpr:
        return TypedExpr(abs(x.expr), x.dtype)  # type: ignore[arg-type]

    @staticmethod
    def square(x: TypedExpr) -> TypedExpr:
        return TypedExpr(x.expr * x.expr, x.dtype)

    @staticmethod
    def add(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr + y.expr, result_type)

    @staticmethod
    def sub(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr - y.expr, result_type)

    @staticmethod
    def mul(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(x.expr * y.expr, result_type)

    @staticmethod
    def neg(x: TypedExpr) -> TypedExpr:
        return TypedExpr(-x.expr, x.dtype)

    @staticmethod
    def floordiv(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        if not is_integer_dtype(result_type):
            return NotImplemented

        return TypedExpr(FloorDiv(x.expr, y.expr), result_type)

    @staticmethod
    def mod(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
        result_type = torch.promote_types(x.dtype, y.dtype)
        if not is_integer_dtype(result_type):
            return NotImplemented

        result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
        return TypedExpr(result_expr, result_type)

    @staticmethod
    def remainder(x: TypedExpr, y: TypedExpr) -> Optional[TypedExpr]:
        result_type = torch.promote_types(x.dtype, y.dtype)
        if not is_integer_dtype(result_type):
            return NotImplemented

        x_expr = sympy.sympify(x.expr)
        y_expr = sympy.sympify(y.expr)
        # In these cases, remainder in Python == remainder in C++, so this
        # transformation is sound
        if (
            x_expr.is_nonnegative is not None
            and x_expr.is_nonnegative == y_expr.is_positive
        ):
            result_expr = ModularIndexing(x.expr, sympy.Integer(1), y.expr)
            return TypedExpr(result_expr, result_type)
        return NotImplemented

    @staticmethod
    def minimum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(sympy.Min(x.expr, y.expr), result_type)

    @staticmethod
    def maximum(x: TypedExpr, y: TypedExpr) -> TypedExpr:
        result_type = torch.promote_types(x.dtype, y.dtype)
        return TypedExpr(sympy.Max(x.expr, y.expr), result_type)
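

# Illustrative sketch (not executed here): SymPyOps methods consume and produce
# TypedExpr values, mirroring the IR example in the module docstring. Assuming
# ``x`` is an integer sympy symbol standing in for an index variable:
#
#     x = sympy.Symbol("x", integer=True, nonnegative=True)
#     tmp0 = SymPyOps.index_expr(x, torch.int32)
#     tmp1 = SymPyOps.constant(2, torch.int32)
#     tmp2 = SymPyOps.mul(tmp0, tmp1)  # TypedExpr(expr=2*x, dtype=torch.int32)
#
# Methods that cannot produce a sound sympy expression for the promoted dtype
# (e.g. floordiv on a floating-point result type) return NotImplemented, which
# IndexPropagation below treats as a signal to fall back to the wrapped handler.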


@dataclass
class IndexPropVar:
    value: Any  # Either an IR value, or TypedExpr if is_symbolic is true
    is_symbolic: bool = False

    @staticmethod
    def new_symbolic(expr: TypedExpr) -> "IndexPropVar":
        return IndexPropVar(expr, is_symbolic=True)

    def __post_init__(self):
        assert not self.is_symbolic or isinstance(
            self.value, TypedExpr
        ), "Symbolic IndexPropVar must contain a TypedExpr"


IndexPropResult: TypeAlias = Union[IndexPropVar, Tuple["IndexPropResult", ...]]


class IndexPropagation:
    """Ops wrapper that tries to propagate constant and index_expr values
    through the computation.

    This aims to maximize the compile-time simplification possible, and to
    convert indirect indexing from arange into normal static indexing.

    """

    def __init__(
        self,
        inner: Any,
        iter_ranges: Dict[sympy.Symbol, sympy.Expr],
        indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr],
    ) -> None:
        self._inner = inner
        self.shape_env = V.graph.sizevars.shape_env

        var_to_range = {
            k: ValueRanges(0, upper_bound(v) - 1) for k, v in iter_ranges.items()
        }
        self.var_to_range = tuple(
            itertools.chain(self.shape_env.var_to_range.items(), var_to_range.items())
        )
        # NOTE: this is intentionally kept as a reference so the caller can
        # update it in-place
        self.indirect_var_ranges = indirect_var_ranges

        axioms = []
        for x, s in iter_ranges.items():
            axioms.append(0 <= x)
            axioms.append(x < s)
        self.axioms = tuple(axioms) + self.shape_env.get_axioms()

    def materialize_expr(self, expr: sympy.Expr, dtype: torch.dtype) -> Any:
        # Construct a new constant/index_expr from the SymPy expression
        if _is_constant(expr):
            val = dtype_to_type(dtype)(expr)
            return self._inner.constant(val, dtype)
        return self._inner.index_expr(expr, dtype)

    def unwrap(self, a: Union[Any, IndexPropVar]) -> Any:
        if isinstance(a, (list, tuple)):
            return tuple(self.unwrap(v) for v in a)

        if not isinstance(a, IndexPropVar):
            return a

        # Prefer the sympy representation if possible
        if a.is_symbolic:
            return self.materialize_expr(a.value.expr, a.value.dtype)

        return a.value

    def wrap(self, a) -> IndexPropResult:
        if isinstance(a, (list, tuple)):
            return tuple(self.wrap(v) for v in a)
        return IndexPropVar(a)
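
    # Illustrative sketch of the wrap/unwrap round trip (assuming ``prop`` is an
    # IndexPropagation instance): symbolic values are re-materialized through
    # the inner handler, while everything else passes through untouched.
    #
    #     v = IndexPropVar.new_symbolic(TypedExpr(sympy.Symbol("x"), torch.int64))
    #     prop.unwrap(v)         # calls self._inner.index_expr(x, torch.int64)
    #     prop.unwrap("buf0")    # -> "buf0"
    #     prop.wrap(("a", "b"))  # -> (IndexPropVar("a"), IndexPropVar("b"))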

    @overload
    def fallback(
        self,
        name: Literal["indirect_indexing"],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> IndexPropVar:
        ...

    @overload
    def fallback(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        ...

    def fallback(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        # Fallback to the wrapped handler
        new_args = [self.unwrap(a) for a in args]
        new_kwargs = {k: self.unwrap(v) for k, v in kwargs.items()}
        return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))

    def propagate_sympy(
        self, name: str, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> IndexPropResult:
        # Build a new SymPy expression from this ops call
        def unwrap(a: Union[Any, IndexPropVar]) -> Any:
            if not isinstance(a, IndexPropVar):
                return a
            return a.value

        new_args = [unwrap(a) for a in args]
        new_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
        new_expr = getattr(SymPyOps, name)(*new_args, **new_kwargs)
        is_valid_expr = new_expr is not NotImplemented and (
            # Inductor doesn't expect floating point in sympy expressions, but
            # allow floating point constants to be propagated
            new_expr.is_constant()
            or new_expr.expr.is_integer
        )
        if not is_valid_expr:
            return self.fallback(name, args, kwargs)
        return IndexPropVar.new_symbolic(new_expr)

    def __getattr__(self, name: str) -> Callable[..., IndexPropResult]:
        def inner(*args: Any, **kwargs: Any) -> IndexPropResult:
            if not hasattr(SymPyOps, name):
                return self.fallback(name, args, kwargs)

            var_arguments = [
                a
                for a in itertools.chain(args, kwargs.values())
                if isinstance(a, IndexPropVar)
            ]
            if not all(v.is_symbolic for v in var_arguments):
                return self.fallback(name, args, kwargs)

            return self.propagate_sympy(name, args, kwargs)

        return inner
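
    # Dispatch sketch (descriptive only): for an ops call such as ops.add(a, b),
    # ``__getattr__`` above checks whether SymPyOps defines the op and whether
    # every IndexPropVar argument is symbolic. If both hold, the call is
    # evaluated symbolically via propagate_sympy; otherwise it falls back to
    # the wrapped handler (for example, there is no SymPyOps.sin, so ops.sin
    # always takes the fallback path).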

    def statically_true(self, e):
        """
        Given an expression, return whether it can be statically proven true
        using value ranges, guard knowledge and runtime_asserts.

        FIXME I think this may not be entirely right, as we may not be able to
        use all runtime_asserts. If this is an issue, just use the guards in
        `self.axioms`.

        The proper way of handling this would be to have a global shape_env that
        adds runtime_asserts as they happen in the code. Then, it should be used
        in SimplifyIndexing to perform wrap_expr and in CSEProxy.check_bounds to
        elide upper / lower bounds also for indirect_indexing.
        """
        var_to_range = (
            *self.var_to_range,
            *(
                (k, ValueRanges(0, upper_bound(v) - 1))
                for k, v in self.indirect_var_ranges.items()
            ),
        )
        return evaluate_expr(self.shape_env, e, self.axioms, var_to_range)

    def indirect_indexing(
        self,
        index: Union[Any, IndexPropVar],
        size: Any,
        check: bool = True,
        wrap_neg=True,
    ) -> Any:
        if isinstance(index, IndexPropVar) and index.is_symbolic:
            # If we can convert the index into a direct (sympy) indexing
            # expression, we do so. We still need to (perhaps) wrap the
            # expression and add bound checks. We want to do this "constant
            # folding" because we don't allow fusing kernels into indirect
            # indexing.

            expr = sympy.sympify(index.value.expr)

            # TODO Perhaps move this logic to the simplify indexing pass
            def wrap_expr(expr):
                # Cases: provably non-negative, provably negative, or mixed sign
                if self.statically_true(0 <= expr):
                    return expr
                elif self.statically_true(expr < 0):
                    return expr + size
                else:
                    return Where(expr < 0, expr + size, expr)

            # Sometimes it's easier to prove 0 <= expr than the weaker -size <= expr
            can_prove_lower = self.statically_true(0 <= expr) or self.statically_true(
                -size <= expr
            )
            can_prove_upper = self.statically_true(expr < size)
            if wrap_neg:
                expr = wrap_expr(expr)
            if generate_assert(check):
                self.fallback(
                    "check_bounds",
                    (expr, size),
                    dict(lower=not can_prove_lower, upper=not can_prove_upper),
                )
            return expr

        indirect_var = self.fallback(
            "indirect_indexing", (index, size, check, wrap_neg), {}
        ).value
        return indirect_var
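

# End-to-end sketch (illustrative, assuming ``prop`` is an IndexPropagation
# instance installed as the ops handler and ``x`` is an integer-valued index
# symbol): an indirect_indexing call whose index is symbolic is folded into a
# direct sympy index, so the inner handler never sees the indirection.
#
#     tmp = prop.index_expr(x, torch.int64)      # symbolic IndexPropVar
#     idx = prop.indirect_indexing(tmp, x_size)  # sympy expression; wrapped with
#                                                # Where(...) only if x's sign
#                                                # cannot be proven
#     prop.load("buf0", idx)                     # inner handler sees a direct index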