xref: /aosp_15_r20/external/pytorch/tools/autograd/context.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import functools
2from typing import Callable
3
4from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI
5from torchgen.context import native_function_manager
6from torchgen.utils import T
7
8
# Like torchgen.context.with_native_function, but for
# NativeFunctionWithDifferentiabilityInfo.
def with_native_function_with_differentiability_info(
    func: Callable[[NFWDI], T]
) -> Callable[[NFWDI], T]:
    """Decorate ``func`` so every call runs under ``native_function_manager``.

    Args:
        func: Callable that takes a single ``NFWDI`` argument.

    Returns:
        A callable with the same signature. It enters
        ``native_function_manager(f.func)`` and, inside that context,
        delegates to ``func``. The metadata of ``func`` is copied onto it
        via ``functools.update_wrapper``.
    """

    def managed_call(f: NFWDI) -> T:
        # The manager is keyed on the ``func`` attribute of the argument,
        # not on the argument itself.
        with native_function_manager(f.func):
            return func(f)

    return functools.update_wrapper(managed_call, func)
20
21
22# Like the above but with an additional dispatch key string argument
def with_native_function_with_differentiability_info_and_key(
    func: Callable[[NFWDI, str], T]
) -> Callable[[NFWDI, str], T]:
    """Decorate a two-argument ``func`` so it runs under ``native_function_manager``.

    Args:
        func: Callable that takes an ``NFWDI`` and a ``str`` key.

    Returns:
        A callable with the same signature. It enters
        ``native_function_manager(f.func)`` and, inside that context,
        delegates to ``func(f, key)``. The metadata of ``func`` is copied
        onto it via ``functools.update_wrapper``.
    """

    def managed_call(f: NFWDI, key: str) -> T:
        # Only ``f.func`` feeds the manager; ``key`` is forwarded to ``func``
        # untouched.
        with native_function_manager(f.func):
            return func(f, key)

    return functools.update_wrapper(managed_call, func)
32