import functools
from typing import Callable

from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI
from torchgen.context import native_function_manager
from torchgen.utils import T


# Like torchgen.context.with_native_function, but for
# NativeFunctionWithDifferentiabilityInfo.
def with_native_function_with_differentiability_info(
    func: Callable[[NFWDI], T]
) -> Callable[[NFWDI], T]:
    """Wrap ``func`` so each call runs inside ``native_function_manager(f.func)``.

    Args:
        func: A callable taking a ``NativeFunctionWithDifferentiabilityInfo``.

    Returns:
        A wrapper with the same signature and metadata as ``func``
        (via ``functools.wraps``). It enters
        ``native_function_manager(f.func)`` before calling ``func(f)`` and
        returns its result unchanged.
    """

    @functools.wraps(func)
    def wrapper(f: NFWDI) -> T:
        # The context manager is keyed on the underlying NativeFunction
        # (``f.func``), not on the NFWDI wrapper itself.
        with native_function_manager(f.func):
            return func(f)

    return wrapper


# Like the above but with an additional dispatch key string argument
def with_native_function_with_differentiability_info_and_key(
    func: Callable[[NFWDI, str], T]
) -> Callable[[NFWDI, str], T]:
    """Wrap ``func`` so each call runs inside ``native_function_manager(f.func)``.

    This is identical to ``with_native_function_with_differentiability_info``
    except that the wrapped callable also receives a string ``key``. The
    comment above describes ``key`` as a dispatch key; it is forwarded
    untouched.

    Args:
        func: A callable taking a ``NativeFunctionWithDifferentiabilityInfo``
            and a ``str`` key.

    Returns:
        A wrapper with the same signature and metadata as ``func``. It enters
        ``native_function_manager(f.func)`` before calling ``func(f, key)`` and
        returns its result unchanged.
    """

    @functools.wraps(func)
    def wrapper(f: NFWDI, key: str) -> T:
        with native_function_manager(f.func):
            return func(f, key)

    return wrapper