# mypy: allow-untyped-defs
import functools
import logging
import operator
import sys
from typing import Any, Dict, Optional, Set, TYPE_CHECKING


# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow
if TYPE_CHECKING:
    import sympy

    from torch.fx.experimental.symbolic_shapes import ShapeEnv
else:
    ShapeEnv = Any

import torch
import torch.utils._pytree as pytree
from torch import fx
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx._compatibility import compatibility
from torch.fx._utils import lazy_format_graph_code
from torch.fx.experimental.proxy_tensor import py_sym_types
from torch.fx.experimental.sym_node import SymNode
from torch.fx.graph_module import GraphModule


__all__ = ["insert_deferred_runtime_asserts"]

log = logging.getLogger(__name__)
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")


def _get_example_value(node: fx.Node) -> Optional[Any]:
    """
    Get the example value for a node, since dynamo uses "example_value"
    while non-strict export uses "val".
    """
    if "example_value" in node.meta:
        return node.meta["example_value"]
    elif "val" in node.meta:
        return node.meta["val"]
    else:
        return None


def _get_sym_val(node: fx.Node) -> Optional["sympy.Expr"]:
    val = _get_example_value(node)
    if isinstance(val, py_sym_types):
        return val.node.expr
    return None


@compatibility(is_backward_compatible=True)
def insert_deferred_runtime_asserts(
    gm: GraphModule,
    shape_env: ShapeEnv,
    name: str,
    export: bool = False,
) -> None:
61    """
62    During tracing, we may have discovered that some data-dependent values
63    had runtime assert on them; e.g., torch.empty(x.item()) induces a runtime
64    that x.item() >= 0.  This asserts can happen unpredictably during fake
65    tensor propagation, so we cannot conveniently insert them into the FX graph
66    when they occur.  Instead, we accumulate them in the ShapeEnv, and in this
67    pass insert them into the graph as proper tests.
68
69    This pass also deduplicates size-related computation, CSE-ing ops that produce
70    symbolic values and/or are involved in runtime asserts. Additionally, shape calls
71    (size/stride/storage_offset) are turned into compute on input sizes if possible,
72    allowing intermediate tensors to be freed earlier. For example, here dynamo will
73    DCE the cat and repeat calls:
74
75        z = torch.cat([x, x], dim=0)  # 2*s0
76        w = z.repeat(y.shape[0])  # 2*s0*s1
77        _w = w.shape[0]
78        # something with _w, but not w ...
79
80        # turns into ->
81        _w0 = 2 * s0
82        _w = _w0 * s1
83
84        # where s0, s1 are either SymInt graph inputs, or the result of added size calls
85
86    Redundant torch._check or torch.ops.aten._assert_scalar.default calls that assert
87    the same expression, and redundant constrain_range calls are also deduplicated.
88    Additionally, because single-symbol bound checks (e.g. u0 >= 0, u0 <= 5) accumulate
89    information in the ShapeEnv, the ShapeEnv contains min/max bounds for each symbol,
90    and we delete all previous calls, adding bound checks at the end of this pass.
91    """
92
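    # A minimal usage sketch (illustrative; assumes `gm` was traced with symbolic
    # shapes and `shape_env` is the ShapeEnv used during tracing):
    #
    #     insert_deferred_runtime_asserts(gm, shape_env, "my_graph", export=True)
    #     gm.recompile()
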
    # Import sympy locally
    import sympy

    from torch._export.passes._node_metadata_hook import _set_node_metadata_hook
    from torch.fx.experimental.symbolic_shapes import (
        _has_uninterpretable_sympy_function,
        CallMethodKey,
        cast_symbool_to_symint_guardless,
        ConvertIntKey,
        DivideByKey,
        free_symbols,
        InnerTensorKey,
        resolve_unbacked_bindings,
    )
    from torch.utils._sympy.numbers import int_oo
    from torch.utils._sympy.reference import PythonReferenceAnalysis
    from torch.utils._sympy.value_ranges import ValueRanges

    # TODO: Request simplification on runtime asserts before emitting them
    ras_by_symbol = shape_env.deferred_runtime_asserts.copy()
    graph = gm.graph
    graph_code_log.debug(
        "%s",
        lazy_format_graph_code(
            f"pre insert_deferred_runtime_asserts {name}", gm, colored=True
        ),
    )

    # We are going to mutate the dict
    expr_to_proxy: Dict[sympy.Expr, fx.Proxy] = {}
    placeholders = set()
    first_non_placeholder = None
    for node in graph.nodes:
        if node.op != "placeholder":
            first_non_placeholder = node
            break
        else:
            placeholders.add(node)

    def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool:
        """
        If this is a size/stride/storage_offset call on an intermediate tensor,
        we can try to compute the value from input shapes instead.
        """
        return (
            (val := _get_sym_val(node)) is not None
            and not isinstance(val, sympy.Number)
            # this holds back from reifying anything in torch.utils._sympy.functions that's unsupported
            and not _has_uninterpretable_sympy_function(val)
            and any(
                isinstance(arg, fx.Node)
                and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size))
                and arg.op != "placeholder"
                for arg in node.args
            )
        )
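
    # Illustrative example for _is_intermediate_tensor_sym_call: given
    #     z = torch.cat([x, x]); n = z.shape[0]   # sym val 2*s0
    # the sym_size.int call on the intermediate `z` qualifies, so the pass can
    # recompute 2*s0 from input sizes and let `z` be freed earlier.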

    # Figure out what key to use, val or example_value
    val_key = "val"
    for node in graph.nodes:
        if "example_value" in node.meta:
            val_key = "example_value"
            break
        elif "val" in node.meta:
            break

    def _node_metadata_hook(
        node: torch.fx.Node,
        stack_trace: Optional[str] = None,
        nn_module_stack: Optional[Dict[str, Any]] = None,
    ) -> None:
        fake_args = [
            _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg
            for arg in node.args
        ]
        try:
            node.meta[val_key] = node.target(*fake_args)  # type: ignore[operator]
        except NotImplementedError:
            # This can happen when attempting to reify a symbol with an unsupported call_function node,
            # e.g. with NestedTensors + sym_size.int via match_symbol().
            # This seems to be fine, as the node gets CSE'd and deleted later in favor of a SymInt graph input.
            pass
        if stack_trace is not None:
            node.meta["stack_trace"] = stack_trace
        if nn_module_stack is not None:
            node.meta["nn_module_stack"] = nn_module_stack

    # Track asserts/checks we've added
    added_asserts: Set[sympy.Expr] = set()
    constrained_unbacked_symbols: Set[sympy.Symbol] = set()

    def _sympy_interp(expr_to_proxy, expr):
        # sympy_interp() with hash consing
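        # e.g. (illustrative): interpreting s0*2 + 1 reifies s0*2 as an FX node
        # and caches it in expr_to_proxy, so any later expression containing
        # s0*2 reuses that node instead of emitting fresh compute.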
        from sympy import Integer, Number, Symbol
        from sympy.logic.boolalg import BooleanAtom

        from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp

        # hash cons
        if expr in expr_to_proxy:
            return expr_to_proxy[expr]
        # base cases, don't cache
        if isinstance(expr, (Integer, Number, Symbol, BooleanAtom)):
            return sympy_interp(PythonReferenceAnalysis, expr_to_proxy, expr)

        # hash cons on arguments, run expr handler
        expr_to_proxy[expr] = _run_sympy_handler(
            PythonReferenceAnalysis,
            [_sympy_interp(expr_to_proxy, arg) for arg in expr.args],
            expr,
        )
        return expr_to_proxy[expr]

    def _is_bound_expr_for_symbol(expr: "sympy.Expr") -> bool:
        # This is probably unnecessary, but since torch._check() calls for single-symbol bounds
        # like u0 >= 0, 10 >= u0 accumulate range info in the ShapeEnv, we designate these calls as redundant
        # and instead add two runtime asserts at the end of this pass, if the min/max bounds are non-trivial.
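        # e.g. (illustrative): u0 >= 0 and 10 >= u0 qualify; u0 + u1 <= 5 and
        # u0 <= s0 do not, since one side must be a lone symbol and the other a number.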
        if len(expr.args) != 2 or expr.func not in (sympy.LessThan, sympy.GreaterThan):
            return False
        lhs, rhs = expr.args
        return (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Number)) or (
            isinstance(rhs, sympy.Symbol) and isinstance(lhs, sympy.Number)
        )

    def add_runtime_asserts(ras):
        for ra in ras:
            if (
                # redundant
                ra.expr in added_asserts
                # if we've already added a constrain_range call for this symbol,
                # then single-symbol bound asserts like u0 >= 0, u0 <= 5 are redundant.
                or (
                    len(ra.expr.free_symbols) == 1
                    and next(iter(ra.expr.free_symbols)) in constrained_unbacked_symbols
                    and _is_bound_expr_for_symbol(ra.expr)
                )
                # don't try to reify sympy functions we can't turn into FX nodes
                or _has_uninterpretable_sympy_function(ra.expr)
            ):
                continue

            log.debug("inserting runtime assert %s", ra.expr)
            # Need to process ALL free symbols, not just unbacked ones
            fvs = free_symbols(ra.expr)
            missing = fvs - expr_to_proxy.keys()
            if missing:
                i1 = min(missing, key=str)
                # TODO: Remove relaxing assert on unbacked_symint https://github.com/pytorch/pytorch/issues/119689
                # assert shape_env.is_unbacked_symint(i1), i1
                ras_by_symbol.setdefault(i1, []).append(ra)
            else:
                # Convert the sympy expression into a sequence of FX
                # nodes
                with _set_node_metadata_hook(gm, _node_metadata_hook):
                    res = _sympy_interp(expr_to_proxy, ra.expr).node
                    graph.call_function(
                        torch.ops.aten._assert_scalar.default,
                        # TODO: use ra.msg here, but it's pretty
                        # useless right now
                        (
                            res,
                            f"Runtime assertion failed for expression {ra.expr} on node '{res}'",
                        ),
                    )
                added_asserts.add(ra.expr)

    nodes = list(graph.nodes)
    for i, node in enumerate(nodes[:-1]):
        # Placeholders can match symbols, but when we destructure them
        # with size we have to make sure we insert the nodes after all
        # the placeholders
        with graph.inserting_before(
            nodes[i + 1] if node not in placeholders else first_non_placeholder
        ):
            # Unfortunately, this logic still must remain because manual
            # make_fx calls may not explicitly bind all symbolic ints as
            # arguments to the function, so we must infer it from the other
            # arguments
            if (
                node in placeholders
                and (example_value := _get_example_value(node)) is not None
            ):

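                # Illustrative: if this placeholder's example value is a tensor of
                # size (s0, 3), match_symbol's callback emits
                # torch.ops.aten.sym_size.int(node, 0) so that s0 gets a graph proxy.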
                def match_symbol(symint, cb):
                    if (
                        isinstance(symint, torch.SymInt)
                        and isinstance(symint.node, SymNode)
                        and isinstance(s := symint.node.expr, sympy.Symbol)
                        and s not in expr_to_proxy
                    ):
                        with _set_node_metadata_hook(gm, _node_metadata_hook):
                            expr_to_proxy[s] = fx.Proxy(cb())
                        log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])

                match_symbol(example_value, lambda: node)
                if isinstance(t := example_value, torch.Tensor):
                    for i, s in enumerate(t.size()):
                        match_symbol(
                            s,
                            lambda: graph.call_function(
                                torch.ops.aten.sym_size.int, (node, i)
                            ),
                        )
                    if not is_sparse_any(t):
                        for i, s in enumerate(t.stride()):
                            match_symbol(
                                s,
                                lambda: graph.call_function(
                                    torch.ops.aten.sym_stride.int, (node, i)
                                ),
                            )
                        match_symbol(
                            t.storage_offset(),
                            lambda: graph.call_function(
                                torch.ops.aten.sym_storage_offset.default, (node,)
                            ),
                        )

            # Handle asserts that aren't associated with any symbol.  This
            # doesn't really have to be in the loop, as it will only run once;
            # it just needs to happen right after the placeholders.
            # Insert this after placeholders & added sym nodes, and before non-placeholders.
            if node == first_non_placeholder:
                add_runtime_asserts(ras_by_symbol.pop(None, []))  # type: ignore[call-overload]

            # deduplicate asserts already present in graph
            if node.target in (
                torch._check,
                torch.ops.aten._assert_scalar.default,
            ):
                if (
                    node.args[0] == True  # noqa: E712
                    or (assert_expr := _get_sym_val(node.args[0])) in expr_to_proxy
                    or (
                        assert_expr is not None
                        and _is_bound_expr_for_symbol(assert_expr)
                    )
                ):
                    arg = node.args[0]
                    gm.graph.erase_node(node)
                    if isinstance(arg, fx.Node) and not arg.users:
                        gm.graph.erase_node(arg)
                else:
                    added_asserts.add(assert_expr)  # type: ignore[arg-type]

            # hash cons, replace function calls that return torch.SymInts with direct references to
            # FX nodes built up to reify the sympy expression.
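            # e.g. (illustrative): two sym_size.int calls that both evaluate to s0
            # collapse to a single node, while an item() call that mints a new
            # unbacked symbol is kept alive by the guards below.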
            if (
                node.op != "placeholder"
                and (sym_expr := _get_sym_val(node)) is not None
            ):
                # this guards against deleting calls like item() that produce new untracked symbols
                new_untracked_symbols = sym_expr.free_symbols - expr_to_proxy.keys()
                # this guards against deleting calls that produce unbacked bindings we haven't yet seen.
                # In this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint
                # (is backed), but produces an unbacked symbol. In this case keep the node alive.
                new_unbacked_bindings = (
                    resolve_unbacked_bindings(
                        shape_env, node.meta.get("unbacked_bindings", {})
                    ).keys()
                    - expr_to_proxy.keys()
                )

                # maybe re-reify expression, replace current node
                if (
                    sym_expr in expr_to_proxy  # example value is redundant
                    or (
                        # shape call on intermediate tensor, turn into computation on input shapes
                        _is_intermediate_tensor_sym_call(node)
                        and not new_untracked_symbols
                    )
                ) and not new_unbacked_bindings:
                    if _is_intermediate_tensor_sym_call(
                        node
                    ):  # reify from input shapes
                        with _set_node_metadata_hook(
                            gm,
                            functools.partial(
                                _node_metadata_hook,
                                stack_trace=node.meta.get("stack_trace"),
                                nn_module_stack=node.meta.get("nn_module_stack"),
                            ),
                        ):
                            expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr)  # type: ignore[arg-type]
                        # won't try DCE-ing tensor compute here
                    hash_node = expr_to_proxy[sym_expr].node  # type: ignore[arg-type]
                    node.replace_all_uses_with(hash_node)
                    gm.graph.erase_node(node)
                    log.debug(
                        "CSE node %s -> %s for expr %s", node, hash_node, sym_expr
                    )

                # store node in hash cons, don't delete/replace
                elif sym_expr not in expr_to_proxy and not isinstance(
                    sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom)
                ):  # don't hash cons primitives
                    expr_to_proxy[sym_expr] = fx.Proxy(node)  # type: ignore[arg-type]

            # We add sym_constrain_range calls for symbols later in any case if they're size-like or range-constrained,
            # so calls before that are redundant.
            if node.target in (
                torch.ops.aten.sym_constrain_range.default,
                torch.ops.aten.sym_constrain_range_for_size.default,
            ):
                gm.graph.erase_node(node)

            defs = []

            # AOTAutograd will create new symbols as the unbacked_bindings keys, which PropagateSymInts will set as
            # equivalent, but the refinement calls we perform in this pass may struggle with associating the two.
            # More concretely, when re-exporting/tracing, constraining only the new symbol may not communicate enough
            # information about the old symbol when we re-export, raising errors on data-dependent guards.
            # Call resolve_unbacked_bindings() to get the original symbol if present, otherwise we take it as is.
            if unbacked_bindings := resolve_unbacked_bindings(
                shape_env, node.meta.get("unbacked_bindings")
            ):
                for s, keypath in unbacked_bindings.items():
                    defs.append(s)

                    # TODO: some CSE when generating these nodes can probably
                    # help reduce graph size and improve compile time
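                    # Keypath example (illustrative): a keypath of
                    # (CallMethodKey("size"), SequenceKey(0)) emits
                    # torch.ops.aten.sym_size.int(node, 0); a leading ConvertIntKey
                    # wraps the node in cast_symbool_to_symint_guardless.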
                    def go(node, keypath):
                        if keypath == ():
                            return node
                        if (
                            len(keypath) >= 2
                            and isinstance(keypath[0], CallMethodKey)
                            and isinstance(keypath[1], pytree.SequenceKey)
                        ):
                            if keypath[0].name == "size":
                                return go(
                                    graph.call_function(
                                        torch.ops.aten.sym_size.int,
                                        (node, keypath[1].idx),
                                    ),
                                    keypath[2:],
                                )
                            if keypath[0].name == "stride":
                                return go(
                                    graph.call_function(
                                        torch.ops.aten.sym_stride.int,
                                        (node, keypath[1].idx),
                                    ),
                                    keypath[2:],
                                )
                            return go(
                                graph.call_method(
                                    keypath[0].name, (node, keypath[1].idx)
                                ),
                                keypath[2:],
                            )
                        elif isinstance(keypath[0], CallMethodKey):
                            return go(
                                graph.call_method(keypath[0].name, (node,)), keypath[1:]
                            )
                        elif isinstance(keypath[0], pytree.SequenceKey):
                            return go(
                                graph.call_function(
                                    operator.getitem, (node, keypath[0].idx)
                                ),
                                keypath[1:],
                            )
                        elif isinstance(keypath[0], ConvertIntKey):
                            return go(
                                graph.call_function(
                                    cast_symbool_to_symint_guardless, (node,)
                                ),
                                keypath[1:],
                            )
                        elif isinstance(keypath[0], DivideByKey):
                            # TODO: need to assert divisibility
                            return go(
                                graph.call_function(
                                    operator.floordiv, (node, keypath[0].divisor)
                                ),
                                keypath[1:],
                            )
                        elif isinstance(keypath[0], InnerTensorKey):
                            return go(
                                graph.call_function(
                                    getattr, (node, keypath[0].inner_name)
                                ),
                                keypath[1:],
                            )
                        else:
                            raise AssertionError(f"unrecognized keypath {keypath}")

                    if s not in expr_to_proxy:
                        with _set_node_metadata_hook(gm, _node_metadata_hook):
                            expr_to_proxy[s] = fx.Proxy(go(node, keypath))
                        log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])

            for i0 in defs:
                ras = ras_by_symbol.pop(i0, [])
                # Before we perform any asserts, first apply range
                # refinement.  This is important, because if we are going
                # to retrace the graph (and we typically are if we send
                # the graph to AOTAutograd), we need to make sure we apply
                # range refinement (ala _check_is_size) first, BEFORE we
                # run any of the asserts.  Otherwise, we may decide to
                # perform substitutions based on the asserts which we then
                # can't back out, because value ranges can only be applied
                # to asserts.
                #
                # A perhaps better long term plan is to avoid this order
                # dependence by making it possible to refine ranges on
                # arbitrary expressions, not just symbols.  But it is not
                # so easy to make use of this information, see
                # https://twitter.com/ezyang/status/1745801370299482492
                # We actually made an attempt at this in
                # https://github.com/pytorch/pytorch/pull/119043
                # which didn't work.
                #
                # Another idea for how to do this:
                # - Have bound_sympy be the source of truth of the ranges of any expression
                # - Cache intermediate results for every subexpression of bound_sympy
                # - This cache should be possible to edit to refine ranges
                #
                # One issue with this proposal is that if
                # we have a bound on 2x, we are not going to be able to
                # apply it for 4x.  Similarly, we may have bounds for an
                # equivalent expression that we are not applying because
                # it's not a perfect match (e.g. x < y vs y > x).
                #
                # We already have the first issue, and it's impossible
                # to solve in general, so any implementation on a best
                # effort basis should do.
                #
                # The second issue is a preexisting one. It can be mitigated
                # with a normalisation algorithm. In general, it may also
                # be on a best effort basis, but since our grammar is not
                # terribly difficult, chances are we could even fully
                # normalise SymPy expressions... who knows.
                if i0 in constrained_unbacked_symbols:
                    continue  # constrain symbol just once

                if i0 in shape_env.size_like:
                    if export:
                        graph.call_function(
                            torch.ops.aten.sym_constrain_range_for_size.default,
                            (expr_to_proxy[i0].node,),
                        )
                    else:
                        graph.call_function(
                            torch._check_is_size, (expr_to_proxy[i0].node,)
                        )

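                # e.g. (illustrative): for u0 whose runtime range was narrowed to
                # [2, 100], the block below emits u0 >= 2 and u0 <= 100 asserts,
                # since that range is tighter than the default unspecified range.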
                vr = shape_env.var_to_range[i0]
                if vr.is_int and vr.upper == sys.maxsize - 1:
                    # treat upper bound == sys.maxsize - 1 for int symbols as +oo
                    # to avoid redundant runtime assert
                    vr = ValueRanges(vr.lower, int_oo)
                if not shape_env._default_unspecified_value_range().issubset(vr):
                    # The runtime range is constrained, so add a runtime
                    # assert and also explicitly refine the range
                    # (refinement should not be necessary once runtime
                    # asserts cause refinement, but that's NYI)
                    def convert(s):
                        if s in (int_oo, -int_oo):
                            return None
                        try:
                            return int(s)
                        except TypeError:
                            return None
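                    # e.g. (illustrative): convert(int_oo) -> None, so an infinite
                    # bound emits no assert; convert(sympy.Integer(5)) -> 5.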

                    if (
                        expr_to_proxy[i0].node.target
                        != cast_symbool_to_symint_guardless
                    ):
                        # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts
                        # raises AOTAutograd errors on cast_symbool_to_symint_guardless

                        with _set_node_metadata_hook(
                            gm,
                            functools.partial(
                                _node_metadata_hook,
                                stack_trace=node.meta.get("stack_trace"),
                                nn_module_stack=node.meta.get("nn_module_stack"),
                            ),
                        ):
                            if (min_val := convert(vr.lower)) is not None:
                                ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node
                                graph.call_function(
                                    torch.ops.aten._assert_scalar.default,
                                    (
                                        ge,
                                        f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'",
                                    ),
                                )
                                added_asserts.add(i0 >= min_val)
                            if (max_val := convert(vr.upper)) is not None:
                                le = _sympy_interp(expr_to_proxy, i0 <= max_val).node
                                graph.call_function(
                                    torch.ops.aten._assert_scalar.default,
                                    (
                                        le,
                                        f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'",
                                    ),
                                )
                                added_asserts.add(i0 <= max_val)

                constrained_unbacked_symbols.add(i0)
                add_runtime_asserts(ras)

    # delete unused reified symbols
    for expr, proxy in expr_to_proxy.items():
        if (
            isinstance(expr, sympy.Symbol)
            and proxy.node.op != "placeholder"  # keep placeholders intact
            and not proxy.node.users
        ):
            log.debug("deleting unused reified symbol for %s", expr)
            gm.graph.erase_node(proxy.node)
606