#pragma once

#include <torch/library.h>

namespace torch::autograd {

// Default DispatchKey::Autograd fallback for built-in operators.
// Can be registered for custom operators.
TORCH_API torch::CppFunction autogradNotImplementedFallback();
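
// Example (a sketch; the "myops" library and "myadd" operator are
// hypothetical): make it explicit that a custom operator is
// non-differentiable by registering this fallback as its Autograd kernel:
//
//   TORCH_LIBRARY_IMPL(myops, Autograd, m) {
//     m.impl("myadd", torch::autograd::autogradNotImplementedFallback());
//   }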

// Default DispatchKey::ADInplaceOrView fallback for built-in operators.
// Can be registered for custom operators.
TORCH_API torch::CppFunction autogradNotImplementedInplaceOrViewFallback();
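
// Example (a sketch; "myops" and "myadd_" are hypothetical): an in-place or
// view operator additionally needs a kernel at the ADInplaceOrView key so
// that version counters and view metadata are handled:
//
//   TORCH_LIBRARY_IMPL(myops, ADInplaceOrView, m) {
//     m.impl(
//         "myadd_",
//         torch::autograd::autogradNotImplementedInplaceOrViewFallback());
//   }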

// Default DispatchKey::Autograd fallback for all other operators (i.e. custom
// operators).
TORCH_API torch::CppFunction basicAutogradNotImplementedFallback();
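
// Example (a sketch): unlike the per-operator registrations above, this is
// the kind of boxed fallback that can be installed for a whole dispatch key;
// the wildcard namespace "_" applies it across all operator namespaces:
//
//   TORCH_LIBRARY_IMPL(_, Autograd, m) {
//     m.fallback(torch::autograd::basicAutogradNotImplementedFallback());
//   }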

enum class AutogradFallbackMode {
  Nothing, // Fallback is a redispatch
  Warn, // Fallback raises a warning if backward is called
  Error, // Fallback raises an error if backward is called
};

// Change the behavior of "basicAutogradNotImplementedFallback".
// In Python this is:
// - torch._C._set_autograd_fallback_mode(str) -> None
// - torch._C._get_autograd_fallback_mode() -> str
TORCH_API void setAutogradFallbackMode(AutogradFallbackMode mode);
TORCH_API AutogradFallbackMode getAutogradFallbackMode();
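
// Example (a sketch): selecting how the basic fallback reacts when backward
// reaches an operator without an autograd kernel; from C++:
//
//   torch::autograd::setAutogradFallbackMode(
//       torch::autograd::AutogradFallbackMode::Warn);
//
// and, via the Python bindings above, presumably with the matching lowercase
// string, e.g. torch._C._set_autograd_fallback_mode("warn").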

} // namespace torch::autograd