xref: /aosp_15_r20/external/pytorch/torch/_functorch/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import contextlib
3from typing import Tuple, Union
4
5import torch
6from torch._C._functorch import (
7    get_single_level_autograd_function_allowed,
8    set_single_level_autograd_function_allowed,
9    unwrap_if_dead,
10)
11from torch.utils._exposed_in import exposed_in
12
13
# Public names of this module. Note that `exposed_in` is not defined here; it is
# re-exported from torch.utils._exposed_in.
__all__ = [
    "exposed_in",
    "argnums_t",
    "enable_single_level_autograd_function",
    "unwrap_dead_wrappers",
]
20
21
@contextlib.contextmanager
def enable_single_level_autograd_function():
    """Temporarily turn on the functorch single-level autograd.Function flag.

    On entry, the flag reported by ``get_single_level_autograd_function_allowed()``
    is set to ``True``. On exit, it is restored to the value it had before entry.
    The restore also happens when the ``with`` body raises.

    Yields:
        None
    """
    # Read the previous state *before* entering the try block. If this call
    # raised inside the try, the finally clause would reference an unbound
    # ``prev_state``. That would raise UnboundLocalError and hide the real error.
    prev_state = get_single_level_autograd_function_allowed()
    try:
        set_single_level_autograd_function_allowed(True)
        yield
    finally:
        set_single_level_autograd_function_allowed(prev_state)
30
31
def unwrap_dead_wrappers(args):
    """Return a tuple of ``args`` with every tensor passed through ``unwrap_if_dead``.

    Non-tensor elements are kept unchanged and in their original position.

    Args:
        args: any iterable of values; it is consumed exactly once.

    Returns:
        tuple: the processed elements, in input order.
    """
    # NB: tree_map_only is intentionally not used here, for performance.
    unwrapped = [
        unwrap_if_dead(item) if isinstance(item, torch.Tensor) else item
        for item in args
    ]
    return tuple(unwrapped)
38
39
# Type alias for an "argnums" value: either a single int or a tuple of ints.
argnums_t = Union[int, Tuple[int, ...]]
41