import functools
from typing import Callable

from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI
from torchgen.context import native_function_manager
from torchgen.utils import T


# Like tools.api.context.with_native_function, but for
# NativeFunctionWithDifferentiabilityInfo.
# NOTE(review): the module path above may be stale -- native_function_manager
# is imported from torchgen.context in this file; confirm where
# with_native_function lives now.
def with_native_function_with_differentiability_info(
    func: Callable[[NFWDI], T]
) -> Callable[[NFWDI], T]:
    """Decorate ``func`` so it runs inside ``native_function_manager``.

    Args:
        func: A callable taking a single
            ``NativeFunctionWithDifferentiabilityInfo`` and returning ``T``.

    Returns:
        A callable with the same signature as ``func``.  Each call enters
        ``native_function_manager(f.func)`` and, while that context is
        active, returns ``func(f)``.
    """

    # functools.wraps copies func's name/docstring onto the wrapper.
    @functools.wraps(func)
    def wrapper(f: NFWDI) -> T:
        # The context is keyed on the inner ``f.func``, not on ``f`` itself.
        with native_function_manager(f.func):
            return func(f)

    return wrapper


# Like the above but with an additional dispatch key string argument
def with_native_function_with_differentiability_info_and_key(
    func: Callable[[NFWDI, str], T]
) -> Callable[[NFWDI, str], T]:
    """Decorate a two-argument ``func`` so it runs inside ``native_function_manager``.

    Args:
        func: A callable taking a ``NativeFunctionWithDifferentiabilityInfo``
            and a ``str`` key, and returning ``T``.

    Returns:
        A callable with the same signature as ``func``.  Each call enters
        ``native_function_manager(f.func)`` and, while that context is
        active, returns ``func(f, key)``.  ``key`` is passed through
        untouched.
    """

    # functools.wraps copies func's name/docstring onto the wrapper.
    @functools.wraps(func)
    def wrapper(f: NFWDI, key: str) -> T:
        # The context is keyed on the inner ``f.func``, not on ``f`` itself.
        with native_function_manager(f.func):
            return func(f, key)

    return wrapper