# mypy: allow-untyped-defs
import functools
import logging
import operator
import sys
from typing import Any, Dict, Optional, Set, TYPE_CHECKING


# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow
if TYPE_CHECKING:
    import sympy

    from torch.fx.experimental.symbolic_shapes import ShapeEnv
else:
    ShapeEnv = Any

import torch
import torch.utils._pytree as pytree
from torch import fx
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx._compatibility import compatibility
from torch.fx._utils import lazy_format_graph_code
from torch.fx.experimental.proxy_tensor import py_sym_types
from torch.fx.experimental.sym_node import SymNode
from torch.fx.graph_module import GraphModule


__all__ = ["insert_deferred_runtime_asserts"]

log = logging.getLogger(__name__)
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")


def _get_example_value(node: fx.Node) -> Optional[str]:
    """
    Get the example value for a node, since dynamo uses the "example_value"
    key while non-strict export uses "val".
    """
    if "example_value" in node.meta:
        return node.meta["example_value"]
    elif "val" in node.meta:
        return node.meta["val"]
    else:
        return None


def _get_sym_val(node: fx.Node) -> Optional["sympy.Expr"]:
    val = _get_example_value(node)
    if isinstance(val, py_sym_types):
        return val.node.expr
    return None


@compatibility(is_backward_compatible=True)
def insert_deferred_runtime_asserts(
    gm: GraphModule,
    shape_env: ShapeEnv,
    name: str,
    export: bool = False,
) -> None:
    """
    During tracing, we may have discovered that some data-dependent values
    had runtime asserts on them; e.g., torch.empty(x.item()) induces a runtime
    assert that x.item() >= 0. These asserts can happen unpredictably during
    fake tensor propagation, so we cannot conveniently insert them into the FX
    graph when they occur. Instead, we accumulate them in the ShapeEnv, and in
    this pass insert them into the graph as proper tests.

    This pass also deduplicates size-related computation, CSE-ing ops that produce
    symbolic values and/or are involved in runtime asserts. Additionally, shape calls
    (size/stride/storage_offset) are turned into compute on input sizes if possible,
    allowing intermediate tensors to be freed earlier. For example, here dynamo will
    DCE the cat and repeat calls:

        z = torch.cat([x, x], dim=0)  # 2*s0
        w = z.repeat(y.shape[0])  # 2*s0*s1
        _w = w.shape[0]
        # something with _w, but not w ...

        # turns into ->
        _w0 = 2 * s0
        _w = _w0 * s1

        # where s0, s1 are either SymInt graph inputs, or the result of added size calls

    Redundant torch._check or torch.ops.aten._assert_scalar.default calls that assert
    the same expression, and redundant constrain_range calls, are also deduplicated.
    Additionally, because single-symbol bound checks (e.g. u0 >= 0, u0 <= 5) accumulate
    range information in the ShapeEnv, which tracks min/max bounds for each symbol,
    we delete all such previous calls and add consolidated bound checks at the end of
    this pass.
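
    A minimal usage sketch (illustrative; shape_env is assumed to be the
    ShapeEnv that was active while tracing gm):

        insert_deferred_runtime_asserts(gm, shape_env, "my_graph", export=True)
        gm.recompile()  # the pass mutates gm.graph in place and returns None

    Schematically, a deferred assert like u0 >= 0 on an unbacked symbol
    produced by x.item() is reified as FX nodes roughly like (node names
    illustrative, not the exact generated names):

        u0 = torch.ops.aten._local_scalar_dense.default(x)  # existing item() call
        ge = u0 >= 0  # built via operator.ge by sympy interpretation
        torch.ops.aten._assert_scalar.default(
            ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"
        )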
91 """ 92 93 # Import sympy locally 94 import sympy 95 96 from torch._export.passes._node_metadata_hook import _set_node_metadata_hook 97 from torch.fx.experimental.symbolic_shapes import ( 98 _has_uninterpretable_sympy_function, 99 CallMethodKey, 100 cast_symbool_to_symint_guardless, 101 ConvertIntKey, 102 DivideByKey, 103 free_symbols, 104 InnerTensorKey, 105 resolve_unbacked_bindings, 106 ) 107 from torch.utils._sympy.numbers import int_oo 108 from torch.utils._sympy.reference import PythonReferenceAnalysis 109 from torch.utils._sympy.value_ranges import ValueRanges 110 111 # TODO: Request simplification on runtime asserts before emitting them 112 ras_by_symbol = shape_env.deferred_runtime_asserts.copy() 113 graph = gm.graph 114 graph_code_log.debug( 115 "%s", 116 lazy_format_graph_code( 117 f"pre insert_deferred_runtime_asserts {name}", gm, colored=True 118 ), 119 ) 120 121 # We are going to mutate the dict 122 expr_to_proxy: Dict[sympy.Expr, fx.Proxy] = {} 123 placeholders = set() 124 first_non_placeholder = None 125 for node in graph.nodes: 126 if node.op != "placeholder": 127 first_non_placeholder = node 128 break 129 else: 130 placeholders.add(node) 131 132 def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool: 133 """ 134 If a size/stride/storage offset call on an intermediate tensor, 135 we can try to compute the value from input shapes instead. 136 """ 137 return ( 138 (val := _get_sym_val(node)) is not None 139 and not isinstance(val, sympy.Number) 140 # this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported 141 and not _has_uninterpretable_sympy_function(val) 142 and any( 143 isinstance(arg, fx.Node) 144 and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size)) 145 and arg.op != "placeholder" 146 for arg in node.args 147 ) 148 ) 149 150 # Figure out what key to use, val or example_value 151 val_key = "val" 152 for node in graph.nodes: 153 if "example_value" in node.meta: 154 val_key = "example_value" 155 break 156 elif "val" in node.meta: 157 break 158 159 def _node_metadata_hook( 160 node: torch.fx.Node, 161 stack_trace: Optional[str] = None, 162 nn_module_stack: Optional[Dict[str, Any]] = None, 163 ) -> None: 164 fake_args = [ 165 _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg 166 for arg in node.args 167 ] 168 try: 169 node.meta[val_key] = node.target(*fake_args) # type: ignore[operator] 170 except NotImplementedError: 171 # This can happen when attempting to reify a symbol with an unsupported call_function node, 172 # e.g. with NestedTensors + sym_size.int via match_symbol(). 173 # This seems to be fine, as the node gets CSE'd and deleted later in favor of a SymInt graph input. 
            pass
        if stack_trace is not None:
            node.meta["stack_trace"] = stack_trace
        if nn_module_stack is not None:
            node.meta["nn_module_stack"] = nn_module_stack

    # Track asserts/checks we've added
    added_asserts: Set[sympy.Expr] = set()
    constrained_unbacked_symbols: Set[sympy.Symbol] = set()

    def _sympy_interp(expr_to_proxy, expr):
        # sympy_interp() with hash consing
        from sympy import Integer, Number, Symbol
        from sympy.logic.boolalg import BooleanAtom

        from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp

        # hash cons
        if expr in expr_to_proxy:
            return expr_to_proxy[expr]
        # base cases, don't cache
        if isinstance(expr, (Integer, Number, Symbol, BooleanAtom)):
            return sympy_interp(PythonReferenceAnalysis, expr_to_proxy, expr)

        # hash cons on arguments, run expr handler
        expr_to_proxy[expr] = _run_sympy_handler(
            PythonReferenceAnalysis,
            [_sympy_interp(expr_to_proxy, arg) for arg in expr.args],
            expr,
        )
        return expr_to_proxy[expr]

    def _is_bound_expr_for_symbol(expr: "sympy.Expr") -> bool:
        # This is probably unnecessary, but since torch._check() calls for single-symbol bounds
        # like u0 >= 0, 10 >= u0 accumulate range info in the ShapeEnv, we designate these calls as redundant
        # and instead add 2 runtime asserts at the end of this pass, if the min/max bounds are non-trivial.
        if len(expr.args) != 2 or expr.func not in (sympy.LessThan, sympy.GreaterThan):
            return False
        lhs, rhs = expr.args
        return (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Number)) or (
            isinstance(rhs, sympy.Symbol) and isinstance(lhs, sympy.Number)
        )

    def add_runtime_asserts(ras):
        for ra in ras:
            if (
                # redundant
                ra.expr in added_asserts
                # if we've already added a constrain_range call for this symbol,
                # then single-symbol bound asserts like u0 >= 0, u0 <= 5 are redundant.
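                # (e.g. once u0's range has been constrained to [0, 5],
                # re-asserting u0 >= 0 or u0 <= 5 adds no new information)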
                or (
                    len(ra.expr.free_symbols) == 1
                    and next(iter(ra.expr.free_symbols)) in constrained_unbacked_symbols
                    and _is_bound_expr_for_symbol(ra.expr)
                )
                # don't try to reify sympy functions we can't turn into FX nodes
                or _has_uninterpretable_sympy_function(ra.expr)
            ):
                continue

            log.debug("inserting runtime assert %s", ra.expr)
            # Need to process ALL free symbols, not just unbacked ones
            fvs = free_symbols(ra.expr)
            missing = fvs - expr_to_proxy.keys()
            if missing:
                i1 = min(missing, key=str)
                # TODO: Remove relaxing assert on unbacked_symint https://github.com/pytorch/pytorch/issues/119689
                # assert shape_env.is_unbacked_symint(i1), i1
                ras_by_symbol.setdefault(i1, []).append(ra)
            else:
                # Convert the sympy expression into a sequence of FX
                # nodes
                with _set_node_metadata_hook(gm, _node_metadata_hook):
                    res = _sympy_interp(expr_to_proxy, ra.expr).node
                    graph.call_function(
                        torch.ops.aten._assert_scalar.default,
                        # TODO: use ra.msg here, but it's pretty
                        # useless right now
                        (
                            res,
                            f"Runtime assertion failed for expression {ra.expr} on node '{res}'",
                        ),
                    )
                added_asserts.add(ra.expr)

    nodes = list(graph.nodes)
    for i, node in enumerate(nodes[:-1]):
        # Placeholders can match symbols, but when we destructure them
        # with size we have to make sure we insert the nodes after all
        # the placeholders
        with graph.inserting_before(
            nodes[i + 1] if node not in placeholders else first_non_placeholder
        ):
            # Unfortunately, this logic still must remain because manual
            # make_fx calls may not explicitly bind all symbolic ints as
            # arguments to the function, so we must infer it from the other
            # arguments
            if (
                node in placeholders
                and (example_value := _get_example_value(node)) is not None
            ):

                def match_symbol(symint, cb):
                    if (
                        isinstance(symint, torch.SymInt)
                        and isinstance(symint.node, SymNode)
                        and isinstance(s := symint.node.expr, sympy.Symbol)
                        and s not in expr_to_proxy
                    ):
                        with _set_node_metadata_hook(gm, _node_metadata_hook):
                            expr_to_proxy[s] = fx.Proxy(cb())
                        log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])

                match_symbol(example_value, lambda: node)
                if isinstance(t := example_value, torch.Tensor):
                    for i, s in enumerate(t.size()):
                        match_symbol(
                            s,
                            lambda: graph.call_function(
                                torch.ops.aten.sym_size.int, (node, i)
                            ),
                        )
                    if not is_sparse_any(t):
                        for i, s in enumerate(t.stride()):
                            match_symbol(
                                s,
                                lambda: graph.call_function(
                                    torch.ops.aten.sym_stride.int, (node, i)
                                ),
                            )
                        match_symbol(
                            t.storage_offset(),
                            lambda: graph.call_function(
                                torch.ops.aten.sym_storage_offset.default, (node,)
                            ),
                        )

            # Handle asserts that aren't associated with any symbol. This
            # doesn't really have to be in the loop as it will only run once,
            # it just needs to happen right after the placeholders.
            # insert this after placeholders & added sym nodes, and before non-placeholders.
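            # (asserts with no free symbols are keyed under None in ras_by_symbol)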
            if node == first_non_placeholder:
                add_runtime_asserts(ras_by_symbol.pop(None, []))  # type: ignore[call-overload]

            # deduplicate asserts already present in graph
            if node.target in (
                torch._check,
                torch.ops.aten._assert_scalar.default,
            ):
                if (
                    node.args[0] == True  # noqa: E712
                    or (assert_expr := _get_sym_val(node.args[0])) in expr_to_proxy
                    or (
                        assert_expr is not None
                        and _is_bound_expr_for_symbol(assert_expr)
                    )
                ):
                    arg = node.args[0]
                    gm.graph.erase_node(node)
                    if isinstance(arg, fx.Node) and not arg.users:
                        gm.graph.erase_node(arg)
                else:
                    added_asserts.add(assert_expr)  # type: ignore[arg-type]

            # hash cons, replace function calls that return torch.SymInts with direct references to
            # FX nodes built up to reify the sympy expression.
            if (
                node.op != "placeholder"
                and (sym_expr := _get_sym_val(node)) is not None
            ):
                # this guards against deleting calls like item() that produce new untracked symbols
                new_untracked_symbols = sym_expr.free_symbols - expr_to_proxy.keys()
                # this guards against deleting calls that produce unbacked bindings we haven't yet seen.
                # in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint
                # (is backed), but produces an unbacked symbol. In this case keep the node alive.
                new_unbacked_bindings = (
                    resolve_unbacked_bindings(
                        shape_env, node.meta.get("unbacked_bindings", {})
                    ).keys()
                    - expr_to_proxy.keys()
                )

                # maybe re-reify expression, replace current node
                if (
                    sym_expr in expr_to_proxy
                    or (  # example value is redundant
                        _is_intermediate_tensor_sym_call(node)
                        # shape call on intermediate tensor, turn into computation on input shapes
                        and not new_untracked_symbols
                    )
                ) and not new_unbacked_bindings:
                    if _is_intermediate_tensor_sym_call(
                        node
                    ):  # reify from input shapes
                        with _set_node_metadata_hook(
                            gm,
                            functools.partial(
                                _node_metadata_hook,
                                stack_trace=node.meta.get("stack_trace"),
                                nn_module_stack=node.meta.get("nn_module_stack"),
                            ),
                        ):
                            expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr)  # type: ignore[arg-type]
                        # won't try DCE-ing tensor compute here
                    hash_node = expr_to_proxy[sym_expr].node  # type: ignore[arg-type]
                    node.replace_all_uses_with(hash_node)
                    gm.graph.erase_node(node)
                    log.debug(
                        "CSE node %s -> %s for expr %s", node, hash_node, sym_expr
                    )

                # store node in hash cons, don't delete/replace
                elif sym_expr not in expr_to_proxy and not isinstance(
                    sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom)
                ):  # don't hash cons primitives
                    expr_to_proxy[sym_expr] = fx.Proxy(node)  # type: ignore[arg-type]

            # We add sym_constrain_range calls for symbols later in any case if they're size-like or range-constrained,
            # so calls before that are redundant.
            if node.target in (
                torch.ops.aten.sym_constrain_range.default,
                torch.ops.aten.sym_constrain_range_for_size.default,
            ):
                gm.graph.erase_node(node)

            defs = []

            # AOTAutograd will create new symbols as the unbacked_bindings keys, which PropagateSymInts will set as
            # equivalent, but the refinement calls we perform in this pass may struggle with associating the two.
            # More concretely, when re-exporting/tracing, constraining only the new symbol may not communicate enough
            # information about the old symbol when we re-export, raising errors on data-dependent guards.
            # Call resolve_unbacked_bindings() to get the original symbol if present, otherwise we take it as is.
            if unbacked_bindings := resolve_unbacked_bindings(
                shape_env, node.meta.get("unbacked_bindings")
            ):
                for s, keypath in unbacked_bindings.items():
                    defs.append(s)

                    # TODO: some CSE when generating these nodes can probably
                    # help reduce graph size and improve compile time
                    def go(node, keypath):
                        if keypath == ():
                            return node
                        if (
                            len(keypath) >= 2
                            and isinstance(keypath[0], CallMethodKey)
                            and isinstance(keypath[1], pytree.SequenceKey)
                        ):
                            if keypath[0].name == "size":
                                return go(
                                    graph.call_function(
                                        torch.ops.aten.sym_size.int,
                                        (node, keypath[1].idx),
                                    ),
                                    keypath[2:],
                                )
                            if keypath[0].name == "stride":
                                return go(
                                    graph.call_function(
                                        torch.ops.aten.sym_stride.int,
                                        (node, keypath[1].idx),
                                    ),
                                    keypath[2:],
                                )
                            return go(
                                graph.call_method(
                                    keypath[0].name, (node, keypath[1].idx)
                                ),
                                keypath[2:],
                            )
                        elif isinstance(keypath[0], CallMethodKey):
                            return go(
                                graph.call_method(keypath[0].name, (node,)), keypath[1:]
                            )
                        elif isinstance(keypath[0], pytree.SequenceKey):
                            return go(
                                graph.call_function(
                                    operator.getitem, (node, keypath[0].idx)
                                ),
                                keypath[1:],
                            )
                        elif isinstance(keypath[0], ConvertIntKey):
                            return go(
                                graph.call_function(
                                    cast_symbool_to_symint_guardless, (node,)
                                ),
                                keypath[1:],
                            )
                        elif isinstance(keypath[0], DivideByKey):
                            # TODO: need to assert divisibility
                            return go(
                                graph.call_function(
                                    operator.floordiv, (node, keypath[0].divisor)
                                ),
                                keypath[1:],
                            )
                        elif isinstance(keypath[0], InnerTensorKey):
                            return go(
                                graph.call_function(
                                    getattr, (node, keypath[0].inner_name)
                                ),
                                keypath[1:],
                            )
                        else:
                            raise AssertionError(f"unrecognized keypath {keypath}")

                    if s not in expr_to_proxy:
                        with _set_node_metadata_hook(gm, _node_metadata_hook):
                            expr_to_proxy[s] = fx.Proxy(go(node, keypath))
                        log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])

            for i0 in defs:
                ras = ras_by_symbol.pop(i0, [])
                # Before we perform any asserts, first apply range
                # refinement. This is important, because if we are going
                # to retrace the graph (and we typically are if we send
                # the graph to AOTAutograd), we need to make sure we apply
                # range refinement (ala _check_is_size) first, BEFORE we
                # run any of the asserts. Otherwise, we may decide to
                # perform substitutions based on the asserts which we then
                # can't back out, because value ranges can only be applied
                # to asserts.
                #
                # A perhaps better long term plan is to avoid this order
                # dependence by making it possible to refine ranges on
                # arbitrary expressions, not just symbols. But it is not
                # so easy to make use of this information, see
                # https://twitter.com/ezyang/status/1745801370299482492
                # We actually made an attempt at this in
                # https://github.com/pytorch/pytorch/pull/119043
                # which didn't work.
                #
                # Another idea for how to do this:
                # - Have bound_sympy be the source of truth of the ranges of any expression
                # - Cache intermediate results for every subexpression of bound_sympy
                # - This cache should be possible to edit to refine ranges
                #
                # One issue with this proposal is that if
                # we have a bound on 2x, we are not going to be able to
                # apply it for 4x. Similarly, we may have bounds for an
                # equivalent expression that we are not applying because
                # it's not a perfect match (e.g. x < y vs y > x).
                #
                # The first issue already exists and is impossible
                # to solve in general, so any implementation on a best
                # effort basis should do.
                #
                # The second issue is a preexisting one. It can be mitigated
                # with a normalisation algorithm. In general, it may also
                # be on a best effort basis, but since our grammar is not
                # terribly difficult, chances are we could even fully
                # normalise SymPy expressions... who knows.
                if i0 in constrained_unbacked_symbols:
                    continue  # constrain symbol just once

                if i0 in shape_env.size_like:
                    if export:
                        graph.call_function(
                            torch.ops.aten.sym_constrain_range_for_size.default,
                            (expr_to_proxy[i0].node,),
                        )
                    else:
                        graph.call_function(
                            torch._check_is_size, (expr_to_proxy[i0].node,)
                        )

                vr = shape_env.var_to_range[i0]
                if vr.is_int and vr.upper == sys.maxsize - 1:
                    # treat upper bound == sys.maxsize - 1 for int symbols as +oo
                    # to avoid redundant runtime assert
                    vr = ValueRanges(vr.lower, int_oo)
                if not shape_env._default_unspecified_value_range().issubset(vr):
                    # The runtime range is constrained, so add a runtime
                    # assert and also explicitly refine the range
                    # (refinement should not be necessary once runtime
                    # asserts cause refinement, but that's NYI)
                    def convert(s):
                        if s in (int_oo, -int_oo):
                            return None
                        try:
                            return int(s)
                        except TypeError:
                            return None

                    if (
                        expr_to_proxy[i0].node.target
                        != cast_symbool_to_symint_guardless
                    ):
                        # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts
                        # raises AOTAutograd errors on cast_symbool_to_symint_guardless

                        with _set_node_metadata_hook(
                            gm,
                            functools.partial(
                                _node_metadata_hook,
                                stack_trace=node.meta.get("stack_trace"),
                                nn_module_stack=node.meta.get("nn_module_stack"),
                            ),
                        ):
                            if (min_val := convert(vr.lower)) is not None:
                                ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node
                                graph.call_function(
                                    torch.ops.aten._assert_scalar.default,
                                    (
                                        ge,
                                        f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'",
                                    ),
                                )
                                added_asserts.add(i0 >= min_val)
                            if (max_val := convert(vr.upper)) is not None:
                                le = _sympy_interp(expr_to_proxy, i0 <= max_val).node
                                graph.call_function(
                                    torch.ops.aten._assert_scalar.default,
                                    (
                                        le,
                                        f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'",
                                    ),
                                )
                                added_asserts.add(i0 <= max_val)

                constrained_unbacked_symbols.add(i0)
                add_runtime_asserts(ras)

    # delete unused reified symbols
    for expr, proxy in expr_to_proxy.items():
        if (
            isinstance(expr, sympy.Symbol)
            and proxy.node.op != "placeholder"  # keep placeholders intact
            and not proxy.node.users
        ):
            log.debug("deleting unused reified symbol for %s", expr)
            gm.graph.erase_node(proxy.node)