1 #pragma once 2 3 #include <torch/library.h> 4 5 namespace torch::autograd { 6 7 // Default DispatchKey::Autograd fallback for built-in operators. 8 // Can be registered for custom operators. 9 TORCH_API torch::CppFunction autogradNotImplementedFallback(); 10 11 // Default DispatchKey::AdInplaceOrView fallback for built-in operators 12 // Can be registered for custom operators. 13 TORCH_API torch::CppFunction autogradNotImplementedInplaceOrViewFallback(); 14 15 // Default DispatchKey::Autograd fallback for all other operators (i.e. custom 16 // operators) 17 TORCH_API torch::CppFunction basicAutogradNotImplementedFallback(); 18 19 enum class AutogradFallbackMode { 20 Nothing, // Fallback is a redispatch 21 Warn, // Fallback raises a warning if backward is called 22 Error, // Fallback raises an error if backward is called 23 }; 24 25 // Change the behavior of "basicAutogradNotImplementedFallback" 26 // In Python this is: 27 // - torch._C._set_autograd_fallback_mode(str) -> None 28 // - torch._C._get_autograd_fallback_mode() -> str 29 TORCH_API void setAutogradFallbackMode(AutogradFallbackMode mode); 30 TORCH_API AutogradFallbackMode getAutogradFallbackMode(); 31 32 } // namespace torch::autograd 33