1# mypy: allow-untyped-defs 2import contextlib 3from typing import Tuple, Union 4 5import torch 6from torch._C._functorch import ( 7 get_single_level_autograd_function_allowed, 8 set_single_level_autograd_function_allowed, 9 unwrap_if_dead, 10) 11from torch.utils._exposed_in import exposed_in 12 13 14__all__ = [ 15 "exposed_in", 16 "argnums_t", 17 "enable_single_level_autograd_function", 18 "unwrap_dead_wrappers", 19] 20 21 22@contextlib.contextmanager 23def enable_single_level_autograd_function(): 24 try: 25 prev_state = get_single_level_autograd_function_allowed() 26 set_single_level_autograd_function_allowed(True) 27 yield 28 finally: 29 set_single_level_autograd_function_allowed(prev_state) 30 31 32def unwrap_dead_wrappers(args): 33 # NB: doesn't use tree_map_only for performance reasons 34 result = tuple( 35 unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg for arg in args 36 ) 37 return result 38 39 40argnums_t = Union[int, Tuple[int, ...]] 41