1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# pyre-strict 8 9from typing import Iterable, List, Optional, Set, Union 10 11import sympy 12 13import torch 14from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges 15 16 17def eval_expr(symint: Union[int, torch.SymInt]) -> Optional[int]: 18 """ 19 Evaluate a symint to int. Returns None if symint's symoblic expr 20 can not be evaluated to valid integer according to the hints. 21 """ 22 if isinstance(symint, int): 23 return symint 24 node = symint.node 25 shape_env = node.shape_env 26 expr = node.expr 27 try: 28 output = shape_env.size_hint(expr) 29 except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: 30 return None 31 return int(output) 32 33 34def eval_upper_bound(maybe_symint: Union[int, torch.SymInt]) -> int: 35 """ 36 Evaluate a symint to its uppper bound value. Returns None if symint's symoblic expr's 37 upper bound can not be evaluated to valid integer according to the constraints in shape_env. 38 """ 39 if isinstance(maybe_symint, int): 40 return maybe_symint 41 node = maybe_symint.node 42 shape_env = node.shape_env 43 expr = node.expr 44 var_range: ValueRanges = bound_sympy( # pyre-ignore[24] 45 expr, shape_env.var_to_range 46 ) 47 upper_bound = var_range.upper 48 # This import is needed temporarily until we update the pinned torch version. 49 50 try: 51 from torch.utils._sympy.numbers import int_oo # @manual 52 except ImportError: 53 int_oo = None 54 55 if isinstance(upper_bound, sympy.Integer): 56 concrete_upper = int(var_range.upper) 57 assert isinstance( 58 concrete_upper, int 59 ), f"Expect upper bound to be a concrete int but got {concrete_upper}" 60 return concrete_upper 61 elif int_oo is not None and upper_bound is int_oo: 62 return int_oo 63 else: 64 raise RuntimeError( 65 f"Expect upper bound to be sympy.Integer or int_oo. but got {upper_bound}" 66 ) 67 68 69def eval_shape(shape: Iterable[Union[int, torch.SymInt]]): # pyre-ignore[3] 70 """ 71 Shape maybe immutable so we return a new shape. Return None for 72 dimensions that are unbacked e.g. first dimension of nonzero's output. 73 """ 74 new_shape = [] 75 for _, s in enumerate(shape): 76 new_shape.append(eval_expr(s)) 77 return new_shape 78 79 80def eval_shape_upper_bound(shape: Iterable[Union[int, torch.SymInt]]) -> List[int]: 81 new_shape = [] 82 for _, s in enumerate(shape): 83 new_shape.append(eval_upper_bound(s)) 84 return new_shape 85 86 87def collect_free_symbols( 88 shape: Iterable[Union[int, torch.SymInt]] 89) -> Set[sympy.Symbol]: 90 symset = set() 91 for sz in shape: 92 if not isinstance(sz, torch.SymInt): 93 continue 94 symset.update(sz.node.expr.free_symbols) 95 return symset 96