1 #pragma once 2 3 #include <c10/util/irange.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/csrc/autograd/function.h> 6 #include <torch/csrc/autograd/variable.h> 7 8 #include <memory> 9 #include <string> 10 #include <vector> 11 12 namespace torch::autograd { 13 14 struct TORCH_API Error : public Node { ErrorError15 Error(std::string msg, edge_list&& next_edges) 16 : Node(std::move(next_edges)), msg(std::move(msg)) {} 17 ErrorError18 Error(std::string msg) : msg(std::move(msg)) {} 19 20 variable_list apply(variable_list&& inputs) override; 21 22 void compiled_args(CompiledNodeArgs& args) override; 23 variable_list apply_with_saved( 24 const variable_list& inputs, 25 SwapSavedVariables& saved) override; 26 27 std::string msg; 28 }; 29 30 // We print grad_fn names in tensor printing. For functions with backward 31 // NYI, grad_fn=<Error> will be printed if we use Error, which is confusing. So 32 // special case with a new NotImplemented function here. 33 struct TORCH_API NotImplemented : public Error { NotImplementedNotImplemented34 NotImplemented(const std::string& forward_fn, edge_list&& next_edges) 35 : Error( 36 "derivative for " + forward_fn + " is not implemented", 37 std::move(next_edges)) {} 38 NotImplementedNotImplemented39 NotImplemented(const std::string& forward_fn) 40 : Error("derivative for " + forward_fn + " is not implemented") {} 41 }; 42 43 // Identity in forward, Error in backward. Used to implement 44 // @once_differentiable 45 struct TORCH_API DelayedError : public Node { DelayedErrorDelayedError46 DelayedError(std::string msg, int64_t num_inputs) : msg(std::move(msg)) { 47 for (const auto _ [[maybe_unused]] : c10::irange(num_inputs)) { 48 add_input_metadata(Node::undefined_input()); 49 } 50 } 51 52 variable_list apply(variable_list&& inputs) override; 53 54 std::string msg; 55 }; 56 57 struct TORCH_API UndefinedGrad : public Node { UndefinedGradUndefinedGrad58 UndefinedGrad() { 59 add_input_metadata(Node::undefined_input()); 60 } 61 62 variable_list apply(variable_list&& inputs) override; 63 }; 64 65 struct TORCH_API UndefinedGradBackward : public Node { UndefinedGradBackwardUndefinedGradBackward66 UndefinedGradBackward(edge_list&& next_edges) : Node(std::move(next_edges)) {} 67 68 UndefinedGradBackward() = default; 69 70 variable_list apply(variable_list&& inputs) override; 71 compiled_argsUndefinedGradBackward72 void compiled_args(CompiledNodeArgs& args) override {} apply_with_savedUndefinedGradBackward73 variable_list apply_with_saved( 74 const variable_list& inputs, 75 SwapSavedVariables& saved) override { 76 return apply(variable_list(inputs)); 77 } 78 }; 79 80 struct TORCH_API GraphRoot : public Node { GraphRootGraphRoot81 GraphRoot(edge_list functions, variable_list inputs) 82 : Node(std::move(functions)), outputs(std::move(inputs)) { 83 // Ensures calls to stream() on a GraphRoot instance reflect current 84 // stream(s) on devices of root grad tensors at the time the instance is 85 // constructed. 86 for (const auto& t : outputs) { 87 add_input_metadata(t); 88 } 89 } 90 applyGraphRoot91 variable_list apply(variable_list&& inputs) override { 92 return outputs; 93 } 94 95 void compiled_args(CompiledNodeArgs& args) override; 96 variable_list apply_with_saved( 97 const variable_list& inputs, 98 SwapSavedVariables& saved) override; 99 100 variable_list outputs; 101 }; 102 103 struct TORCH_API Identity : public Node { 104 variable_list apply(variable_list&& inputs) override; 105 }; 106 107 } // namespace torch::autograd 108