xref: /aosp_15_r20/external/executorch/exir/sym_util.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerfrom typing import Iterable, List, Optional, Set, Union
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerimport sympy
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerimport torch
14*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Worker
def eval_expr(symint: Union[int, torch.SymInt]) -> Optional[int]:
    """
    Evaluate ``symint`` to a concrete int using the shape env's hints.

    Plain ints are returned unchanged. For a SymInt, the hint for its symbolic
    expression is obtained from ``shape_env.size_hint``. If that raises
    ``GuardOnDataDependentSymNode`` (no hint is available, e.g. for an unbacked
    symbol), None is returned instead of an int.
    """
    if not isinstance(symint, int):
        sym_node = symint.node
        env = sym_node.shape_env
        sym_expr = sym_node.expr
        try:
            hint = env.size_hint(sym_expr)
        except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
            # Data-dependent (hint-less) expressions cannot be evaluated.
            return None
        return int(hint)
    return symint
32*523fa7a6SAndroid Build Coastguard Worker
33*523fa7a6SAndroid Build Coastguard Worker
34*523fa7a6SAndroid Build Coastguard Workerdef eval_upper_bound(maybe_symint: Union[int, torch.SymInt]) -> int:
35*523fa7a6SAndroid Build Coastguard Worker    """
36*523fa7a6SAndroid Build Coastguard Worker    Evaluate a symint to its uppper bound value. Returns None if symint's symoblic expr's
37*523fa7a6SAndroid Build Coastguard Worker    upper bound can not be evaluated to valid integer according to the constraints in shape_env.
38*523fa7a6SAndroid Build Coastguard Worker    """
39*523fa7a6SAndroid Build Coastguard Worker    if isinstance(maybe_symint, int):
40*523fa7a6SAndroid Build Coastguard Worker        return maybe_symint
41*523fa7a6SAndroid Build Coastguard Worker    node = maybe_symint.node
42*523fa7a6SAndroid Build Coastguard Worker    shape_env = node.shape_env
43*523fa7a6SAndroid Build Coastguard Worker    expr = node.expr
44*523fa7a6SAndroid Build Coastguard Worker    var_range: ValueRanges = bound_sympy(  # pyre-ignore[24]
45*523fa7a6SAndroid Build Coastguard Worker        expr, shape_env.var_to_range
46*523fa7a6SAndroid Build Coastguard Worker    )
47*523fa7a6SAndroid Build Coastguard Worker    upper_bound = var_range.upper
48*523fa7a6SAndroid Build Coastguard Worker    # This import is needed temporarily until we update the pinned torch version.
49*523fa7a6SAndroid Build Coastguard Worker
50*523fa7a6SAndroid Build Coastguard Worker    try:
51*523fa7a6SAndroid Build Coastguard Worker        from torch.utils._sympy.numbers import int_oo  # @manual
52*523fa7a6SAndroid Build Coastguard Worker    except ImportError:
53*523fa7a6SAndroid Build Coastguard Worker        int_oo = None
54*523fa7a6SAndroid Build Coastguard Worker
55*523fa7a6SAndroid Build Coastguard Worker    if isinstance(upper_bound, sympy.Integer):
56*523fa7a6SAndroid Build Coastguard Worker        concrete_upper = int(var_range.upper)
57*523fa7a6SAndroid Build Coastguard Worker        assert isinstance(
58*523fa7a6SAndroid Build Coastguard Worker            concrete_upper, int
59*523fa7a6SAndroid Build Coastguard Worker        ), f"Expect upper bound to be a concrete int but got {concrete_upper}"
60*523fa7a6SAndroid Build Coastguard Worker        return concrete_upper
61*523fa7a6SAndroid Build Coastguard Worker    elif int_oo is not None and upper_bound is int_oo:
62*523fa7a6SAndroid Build Coastguard Worker        return int_oo
63*523fa7a6SAndroid Build Coastguard Worker    else:
64*523fa7a6SAndroid Build Coastguard Worker        raise RuntimeError(
65*523fa7a6SAndroid Build Coastguard Worker            f"Expect upper bound to be sympy.Integer or int_oo. but got {upper_bound}"
66*523fa7a6SAndroid Build Coastguard Worker        )
67*523fa7a6SAndroid Build Coastguard Worker
68*523fa7a6SAndroid Build Coastguard Worker
def eval_shape(shape: Iterable[Union[int, torch.SymInt]]) -> List[Optional[int]]:
    """
    Evaluate every dimension of ``shape`` with ``eval_expr``.

    Shape may be immutable (e.g. a tuple), so a new list is always returned
    instead of modifying ``shape`` in place. Entries are None for dimensions
    that are unbacked, e.g. the first dimension of nonzero's output.
    """
    return [eval_expr(dim) for dim in shape]
78*523fa7a6SAndroid Build Coastguard Worker
79*523fa7a6SAndroid Build Coastguard Worker
def eval_shape_upper_bound(shape: Iterable[Union[int, torch.SymInt]]) -> List[int]:
    """
    Evaluate the upper bound of every dimension of ``shape``.

    Each dimension is passed to ``eval_upper_bound`` and the results are
    returned as a new list in the same order. Any error raised by
    ``eval_upper_bound`` propagates to the caller.
    """
    return [eval_upper_bound(dim) for dim in shape]
85*523fa7a6SAndroid Build Coastguard Worker
86*523fa7a6SAndroid Build Coastguard Worker
def collect_free_symbols(
    shape: Iterable[Union[int, torch.SymInt]]
) -> Set[sympy.Symbol]:
    """
    Gather the free sympy symbols used by the dimensions of ``shape``.

    Only ``torch.SymInt`` dimensions contribute; plain int dimensions are
    skipped. The result is the union of ``dim.node.expr.free_symbols`` over
    all SymInt dimensions.
    """
    return {
        sym
        for dim in shape
        if isinstance(dim, torch.SymInt)
        for sym in dim.node.expr.free_symbols
    }
96