xref: /aosp_15_r20/external/pytorch/tools/autograd/context.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport functools
2*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI
5*da0073e9SAndroid Build Coastguard Workerfrom torchgen.context import native_function_manager
6*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import T
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
# Like torchgen.context.with_native_function, but for
# NativeFunctionWithDifferentiabilityInfo.
def with_native_function_with_differentiability_info(
    func: Callable[[NFWDI], T]
) -> Callable[[NFWDI], T]:
    """Decorate ``func`` so that each call runs inside ``native_function_manager``.

    The returned callable takes a single ``NFWDI`` argument ``f``. It enters
    ``native_function_manager(f.func)``, invokes ``func(f)`` inside that
    context, and returns whatever ``func`` returned. ``functools.wraps``
    copies ``func``'s metadata (``__name__``, ``__doc__``, ...) onto the
    returned callable.
    """

    def _call_in_context(f: NFWDI) -> T:
        # Enter the manager for ``f.func`` (presumably the underlying
        # NativeFunction) only for the duration of this one call.
        with native_function_manager(f.func):
            result = func(f)
        return result

    return functools.wraps(func)(_call_in_context)
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
# Like the above, but the wrapped function also takes a dispatch key string argument.
def with_native_function_with_differentiability_info_and_key(
    func: Callable[[NFWDI, str], T]
) -> Callable[[NFWDI, str], T]:
    """Decorate ``func`` so that each call runs inside ``native_function_manager``.

    Same as the single-argument variant, except that the returned callable
    takes an extra ``key: str`` argument and forwards it unchanged as the
    second argument of ``func``. The context is entered with
    ``native_function_manager(f.func)`` and exited after ``func(f, key)``
    returns. ``functools.wraps`` copies ``func``'s metadata onto the
    returned callable.
    """

    def _call_in_context(f: NFWDI, key: str) -> T:
        # Enter the manager for ``f.func`` (presumably the underlying
        # NativeFunction) only for the duration of this one call.
        with native_function_manager(f.func):
            result = func(f, key)
        return result

    return functools.wraps(func)(_call_in_context)
32