1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates. 2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved. 3*523fa7a6SAndroid Build Coastguard Worker# 4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the 5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree. 6*523fa7a6SAndroid Build Coastguard Worker 7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Workerfrom typing import Iterable, List, Optional, Set, Union 10*523fa7a6SAndroid Build Coastguard Worker 11*523fa7a6SAndroid Build Coastguard Workerimport sympy 12*523fa7a6SAndroid Build Coastguard Worker 13*523fa7a6SAndroid Build Coastguard Workerimport torch 14*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils._sympy.value_ranges import bound_sympy, ValueRanges 15*523fa7a6SAndroid Build Coastguard Worker 16*523fa7a6SAndroid Build Coastguard Worker 17*523fa7a6SAndroid Build Coastguard Workerdef eval_expr(symint: Union[int, torch.SymInt]) -> Optional[int]: 18*523fa7a6SAndroid Build Coastguard Worker """ 19*523fa7a6SAndroid Build Coastguard Worker Evaluate a symint to int. Returns None if symint's symoblic expr 20*523fa7a6SAndroid Build Coastguard Worker can not be evaluated to valid integer according to the hints. 21*523fa7a6SAndroid Build Coastguard Worker """ 22*523fa7a6SAndroid Build Coastguard Worker if isinstance(symint, int): 23*523fa7a6SAndroid Build Coastguard Worker return symint 24*523fa7a6SAndroid Build Coastguard Worker node = symint.node 25*523fa7a6SAndroid Build Coastguard Worker shape_env = node.shape_env 26*523fa7a6SAndroid Build Coastguard Worker expr = node.expr 27*523fa7a6SAndroid Build Coastguard Worker try: 28*523fa7a6SAndroid Build Coastguard Worker output = shape_env.size_hint(expr) 29*523fa7a6SAndroid Build Coastguard Worker except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: 30*523fa7a6SAndroid Build Coastguard Worker return None 31*523fa7a6SAndroid Build Coastguard Worker return int(output) 32*523fa7a6SAndroid Build Coastguard Worker 33*523fa7a6SAndroid Build Coastguard Worker 34*523fa7a6SAndroid Build Coastguard Workerdef eval_upper_bound(maybe_symint: Union[int, torch.SymInt]) -> int: 35*523fa7a6SAndroid Build Coastguard Worker """ 36*523fa7a6SAndroid Build Coastguard Worker Evaluate a symint to its uppper bound value. Returns None if symint's symoblic expr's 37*523fa7a6SAndroid Build Coastguard Worker upper bound can not be evaluated to valid integer according to the constraints in shape_env. 38*523fa7a6SAndroid Build Coastguard Worker """ 39*523fa7a6SAndroid Build Coastguard Worker if isinstance(maybe_symint, int): 40*523fa7a6SAndroid Build Coastguard Worker return maybe_symint 41*523fa7a6SAndroid Build Coastguard Worker node = maybe_symint.node 42*523fa7a6SAndroid Build Coastguard Worker shape_env = node.shape_env 43*523fa7a6SAndroid Build Coastguard Worker expr = node.expr 44*523fa7a6SAndroid Build Coastguard Worker var_range: ValueRanges = bound_sympy( # pyre-ignore[24] 45*523fa7a6SAndroid Build Coastguard Worker expr, shape_env.var_to_range 46*523fa7a6SAndroid Build Coastguard Worker ) 47*523fa7a6SAndroid Build Coastguard Worker upper_bound = var_range.upper 48*523fa7a6SAndroid Build Coastguard Worker # This import is needed temporarily until we update the pinned torch version. 49*523fa7a6SAndroid Build Coastguard Worker 50*523fa7a6SAndroid Build Coastguard Worker try: 51*523fa7a6SAndroid Build Coastguard Worker from torch.utils._sympy.numbers import int_oo # @manual 52*523fa7a6SAndroid Build Coastguard Worker except ImportError: 53*523fa7a6SAndroid Build Coastguard Worker int_oo = None 54*523fa7a6SAndroid Build Coastguard Worker 55*523fa7a6SAndroid Build Coastguard Worker if isinstance(upper_bound, sympy.Integer): 56*523fa7a6SAndroid Build Coastguard Worker concrete_upper = int(var_range.upper) 57*523fa7a6SAndroid Build Coastguard Worker assert isinstance( 58*523fa7a6SAndroid Build Coastguard Worker concrete_upper, int 59*523fa7a6SAndroid Build Coastguard Worker ), f"Expect upper bound to be a concrete int but got {concrete_upper}" 60*523fa7a6SAndroid Build Coastguard Worker return concrete_upper 61*523fa7a6SAndroid Build Coastguard Worker elif int_oo is not None and upper_bound is int_oo: 62*523fa7a6SAndroid Build Coastguard Worker return int_oo 63*523fa7a6SAndroid Build Coastguard Worker else: 64*523fa7a6SAndroid Build Coastguard Worker raise RuntimeError( 65*523fa7a6SAndroid Build Coastguard Worker f"Expect upper bound to be sympy.Integer or int_oo. but got {upper_bound}" 66*523fa7a6SAndroid Build Coastguard Worker ) 67*523fa7a6SAndroid Build Coastguard Worker 68*523fa7a6SAndroid Build Coastguard Worker 69*523fa7a6SAndroid Build Coastguard Workerdef eval_shape(shape: Iterable[Union[int, torch.SymInt]]): # pyre-ignore[3] 70*523fa7a6SAndroid Build Coastguard Worker """ 71*523fa7a6SAndroid Build Coastguard Worker Shape maybe immutable so we return a new shape. Return None for 72*523fa7a6SAndroid Build Coastguard Worker dimensions that are unbacked e.g. first dimension of nonzero's output. 73*523fa7a6SAndroid Build Coastguard Worker """ 74*523fa7a6SAndroid Build Coastguard Worker new_shape = [] 75*523fa7a6SAndroid Build Coastguard Worker for _, s in enumerate(shape): 76*523fa7a6SAndroid Build Coastguard Worker new_shape.append(eval_expr(s)) 77*523fa7a6SAndroid Build Coastguard Worker return new_shape 78*523fa7a6SAndroid Build Coastguard Worker 79*523fa7a6SAndroid Build Coastguard Worker 80*523fa7a6SAndroid Build Coastguard Workerdef eval_shape_upper_bound(shape: Iterable[Union[int, torch.SymInt]]) -> List[int]: 81*523fa7a6SAndroid Build Coastguard Worker new_shape = [] 82*523fa7a6SAndroid Build Coastguard Worker for _, s in enumerate(shape): 83*523fa7a6SAndroid Build Coastguard Worker new_shape.append(eval_upper_bound(s)) 84*523fa7a6SAndroid Build Coastguard Worker return new_shape 85*523fa7a6SAndroid Build Coastguard Worker 86*523fa7a6SAndroid Build Coastguard Worker 87*523fa7a6SAndroid Build Coastguard Workerdef collect_free_symbols( 88*523fa7a6SAndroid Build Coastguard Worker shape: Iterable[Union[int, torch.SymInt]] 89*523fa7a6SAndroid Build Coastguard Worker) -> Set[sympy.Symbol]: 90*523fa7a6SAndroid Build Coastguard Worker symset = set() 91*523fa7a6SAndroid Build Coastguard Worker for sz in shape: 92*523fa7a6SAndroid Build Coastguard Worker if not isinstance(sz, torch.SymInt): 93*523fa7a6SAndroid Build Coastguard Worker continue 94*523fa7a6SAndroid Build Coastguard Worker symset.update(sz.node.expr.free_symbols) 95*523fa7a6SAndroid Build Coastguard Worker return symset 96