xref: /aosp_15_r20/external/executorch/exir/sym_util.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
8
9from typing import Iterable, List, Optional, Set, Union
10
11import sympy
12
13import torch
14from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
15
16
def eval_expr(symint: Union[int, torch.SymInt]) -> Optional[int]:
    """
    Resolve ``symint`` to a concrete int using the shape environment's hints.

    A plain ``int`` is passed through unchanged. For a ``torch.SymInt`` the
    node's symbolic expression is evaluated with ``shape_env.size_hint``;
    ``None`` is returned when that raises ``GuardOnDataDependentSymNode``,
    i.e. the expression cannot be resolved to an integer from the hints.
    """
    if not isinstance(symint, int):
        sym_node = symint.node
        env, sym_expr = sym_node.shape_env, sym_node.expr
        try:
            hint = env.size_hint(sym_expr)
        except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
            # No usable hint (e.g. a data-dependent / unbacked value).
            return None
        return int(hint)
    return symint
32
33
34def eval_upper_bound(maybe_symint: Union[int, torch.SymInt]) -> int:
35    """
36    Evaluate a symint to its uppper bound value. Returns None if symint's symoblic expr's
37    upper bound can not be evaluated to valid integer according to the constraints in shape_env.
38    """
39    if isinstance(maybe_symint, int):
40        return maybe_symint
41    node = maybe_symint.node
42    shape_env = node.shape_env
43    expr = node.expr
44    var_range: ValueRanges = bound_sympy(  # pyre-ignore[24]
45        expr, shape_env.var_to_range
46    )
47    upper_bound = var_range.upper
48    # This import is needed temporarily until we update the pinned torch version.
49
50    try:
51        from torch.utils._sympy.numbers import int_oo  # @manual
52    except ImportError:
53        int_oo = None
54
55    if isinstance(upper_bound, sympy.Integer):
56        concrete_upper = int(var_range.upper)
57        assert isinstance(
58            concrete_upper, int
59        ), f"Expect upper bound to be a concrete int but got {concrete_upper}"
60        return concrete_upper
61    elif int_oo is not None and upper_bound is int_oo:
62        return int_oo
63    else:
64        raise RuntimeError(
65            f"Expect upper bound to be sympy.Integer or int_oo. but got {upper_bound}"
66        )
67
68
def eval_shape(shape: Iterable[Union[int, torch.SymInt]]) -> List[Optional[int]]:
    """
    Evaluate every dimension of ``shape`` with ``eval_expr``.

    Shape may be immutable, so a new list is returned rather than modifying
    ``shape`` in place. Plain int dimensions are copied through; a
    ``torch.SymInt`` dimension becomes ``None`` when it is unbacked (e.g. the
    first dimension of nonzero's output) and cannot be resolved from hints.
    """
    return [eval_expr(dim) for dim in shape]
78
79
def eval_shape_upper_bound(shape: Iterable[Union[int, torch.SymInt]]) -> List[int]:
    """
    Map every dimension of ``shape`` to its upper bound via ``eval_upper_bound``.

    Returns a new list; ``shape`` itself is not modified. Plain int dimensions
    are copied through. An unbounded symbolic dimension may come back as
    torch's ``int_oo`` sentinel, and ``eval_upper_bound`` raises
    ``RuntimeError`` for a dimension whose upper bound it cannot resolve.
    """
    return [eval_upper_bound(dim) for dim in shape]
85
86
def collect_free_symbols(
    shape: Iterable[Union[int, torch.SymInt]]
) -> Set[sympy.Symbol]:
    """
    Gather the sympy free symbols referenced by the symbolic dims of ``shape``.

    Entries that are not ``torch.SymInt`` (e.g. plain ints) contribute nothing.
    The result is a new set holding the union of ``dim.node.expr.free_symbols``
    over every ``torch.SymInt`` dimension.
    """
    return {
        symbol
        for dim in shape
        if isinstance(dim, torch.SymInt)
        for symbol in dim.node.expr.free_symbols
    }
96