#pragma once

#include <c10/util/irange.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>

#include <memory>
#include <string>
#include <vector>

namespace torch::autograd {

// Raises the stored `msg` if the backward pass ever reaches this node.
struct TORCH_API Error : public Node {
  Error(std::string msg, edge_list&& next_edges)
      : Node(std::move(next_edges)), msg(std::move(msg)) {}

  Error(std::string msg) : msg(std::move(msg)) {}

  variable_list apply(variable_list&& inputs) override;

  void compiled_args(CompiledNodeArgs& args) override;
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

  std::string msg;
};

// We print grad_fn names when printing tensors. For functions whose backward
// is not yet implemented, grad_fn=<Error> would be printed if we used Error,
// which is confusing, so we special-case them with this NotImplemented node.
struct TORCH_API NotImplemented : public Error {
  NotImplemented(const std::string& forward_fn, edge_list&& next_edges)
      : Error(
            "derivative for " + forward_fn + " is not implemented",
            std::move(next_edges)) {}

  NotImplemented(const std::string& forward_fn)
      : Error("derivative for " + forward_fn + " is not implemented") {}
};
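
// A minimal sketch (hedged; `my_op`, `self`, and `result` are illustrative
// names, not taken from this file): an operator without a defined derivative
// can install a NotImplemented node as the grad_fn of its outputs, so that
// backward() raises "derivative for my_op is not implemented" rather than a
// generic error:
//
//   auto grad_fn = std::make_shared<NotImplemented>("my_op");
//   grad_fn->set_next_edges(collect_next_edges(self));
//   set_history(result, grad_fn);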

// Identity in forward, Error in backward. Used to implement
// @once_differentiable.
struct TORCH_API DelayedError : public Node {
  DelayedError(std::string msg, int64_t num_inputs) : msg(std::move(msg)) {
    for (const auto _ [[maybe_unused]] : c10::irange(num_inputs)) {
      add_input_metadata(Node::undefined_input());
    }
  }

  variable_list apply(variable_list&& inputs) override;

  std::string msg;
};
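
// Usage sketch (hedged; the message and names are illustrative): the
// @once_differentiable decorator wraps a custom Function's outputs in a
// DelayedError. apply() forwards the input data unchanged but installs an
// Error node carrying `msg` as the outputs' grad_fn, so differentiating a
// second time fails loudly instead of silently producing wrong gradients:
//
//   auto delayed = std::make_shared<DelayedError>(
//       "trying to differentiate twice a function that was marked "
//       "with @once_differentiable",
//       /*num_inputs=*/1);
//   variable_list wrapped = delayed->apply({output});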

struct TORCH_API UndefinedGrad : public Node {
  UndefinedGrad() {
    add_input_metadata(Node::undefined_input());
  }

  variable_list apply(variable_list&& inputs) override;
};

struct TORCH_API UndefinedGradBackward : public Node {
  UndefinedGradBackward(edge_list&& next_edges) : Node(std::move(next_edges)) {}

  UndefinedGradBackward() = default;

  variable_list apply(variable_list&& inputs) override;

  void compiled_args(CompiledNodeArgs& args) override {}
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override {
    return apply(variable_list(inputs));
  }
};
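
// Usage sketch (hedged; illustrative): UndefinedGrad forwards its input
// while installing UndefinedGradBackward as the grad_fn, so the backward
// pass delivers undefined tensors as gradients. gradcheck exercises this
// pair (via torch._C._functions.UndefinedGrad) to verify that backward
// functions tolerate undefined grad_outputs:
//
//   auto fn = std::make_shared<UndefinedGrad>();
//   variable_list wrapped = fn->apply({output});
//   // Backward through wrapped[0] now feeds undefined gradients upstream.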

struct TORCH_API GraphRoot : public Node {
  GraphRoot(edge_list functions, variable_list inputs)
      : Node(std::move(functions)), outputs(std::move(inputs)) {
    // Ensures calls to stream() on a GraphRoot instance reflect current
    // stream(s) on devices of root grad tensors at the time the instance is
    // constructed.
    for (const auto& t : outputs) {
      add_input_metadata(t);
    }
  }

  variable_list apply(variable_list&& inputs) override {
    return outputs;
  }

  void compiled_args(CompiledNodeArgs& args) override;
  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override;

  variable_list outputs;
};
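
// Usage sketch (hedged; simplified from how the engine seeds a backward
// pass, with illustrative names): the user-supplied root gradients are
// wrapped in a GraphRoot whose apply() simply replays them into the graph:
//
//   auto graph_root = std::make_shared<GraphRoot>(
//       std::move(root_edges), std::move(root_grads));
//   // graph_root->apply({}) yields root_grads, which the engine then
//   // routes along root_edges into the rest of the autograd graph.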

// Passes its inputs through unchanged.
struct TORCH_API Identity : public Node {
  variable_list apply(variable_list&& inputs) override;
};

} // namespace torch::autograd