#pragma once

#include <ATen/core/ivalue.h>
#include <c10/core/SymInt.h>
#include <c10/util/flat_hash_map.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/variable_info.h>
#include <torch/csrc/dynamo/compiled_autograd.h>
#include <vector>

namespace torch::autograd {

using optional_variable_list = std::vector<std::optional<Variable>>;
using _jvp_fn_t = std::function<variable_list(variable_list, variable_list)>;
using _view_as_self_fn_t = std::function<at::Tensor(at::Tensor)>;

TORCH_API std::vector<std::optional<Variable>> _wrap_outputs(
    const variable_list& input_vars,
    const std::unordered_set<at::TensorImpl*>& non_differentiable,
    const std::unordered_set<at::TensorImpl*>& dirty_inputs,
    const at::ArrayRef<std::optional<Variable>> raw_outputs,
    const std::shared_ptr<Node>& cdata,
    const _jvp_fn_t& jvp_user_function,
    const std::unordered_set<at::TensorImpl*>& to_save_if_setup_context,
    const _view_as_self_fn_t& view_as_self_fn);

TORCH_API void check_variable_result(
    const at::TensorBase& original,
    const at::TensorBase& result,
    const std::string& hook_name);

// Get the return type of the forward function of the custom Function class X
template <typename X, typename... Args>
using forward_t = decltype(X::forward(nullptr, std::declval<Args>()...));
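// For example (illustrative only): if `MyFunction::forward` is declared as
// `static variable_list forward(AutogradContext*, int, Variable)`, then
// `forward_t<MyFunction, int, Variable>` is `variable_list`.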

/// To use custom autograd operations, implement a Function subclass with
/// static forward and backward functions:
///
/// `forward` must take a pointer to `torch::autograd::AutogradContext` as its
/// first argument, can take any number of additional arguments, and should
/// return either a variable list or a single Variable. Any argument passed
/// directly as a Variable is registered in the graph; vectors, sets, and other
/// data structures are not traversed. A `std::optional<Tensor>` argument is
/// registered as a variable in the graph only if it holds a value. Variables
/// can be saved in the `ctx` using `ctx->save_for_backward`
/// (see `torch::autograd::AutogradContext::save_for_backward`) and other data
/// can be saved in the `ctx->saved_data` map
/// (see `torch::autograd::AutogradContext::saved_data`)
/// in the form of `<std::string, at::IValue>` pairs.
///
/// `backward` should take a pointer to `torch::autograd::AutogradContext`
/// and a variable list containing as many Variables as `forward` had outputs.
/// It should return as many Variables as `forward` had inputs, each holding
/// the gradient w.r.t. its corresponding input. Variables saved in `forward`
/// can be accessed with `ctx->get_saved_variables` (see
/// `torch::autograd::AutogradContext::get_saved_variables`) and other saved
/// data can be accessed from `ctx->saved_data`.
/// To enable compiled autograd support (torch.compile for backward) for your
/// custom autograd operation, set MyFunction::is_traceable
/// (see the Function::is_traceable notes below).
///
/// For example:
/// ```
/// class MyFunction : public Function<MyFunction> {
///   public:
///   static constexpr bool is_traceable = true;
///
///   static variable_list forward(AutogradContext *ctx, int n, Variable var) {
///      // Save data for backward in context
///      ctx->saved_data["n"] = n;
///      var.mul_(n);
///      // Mark var as modified by inplace operation
///      ctx->mark_dirty({var});
///      return {var};
///   }
///
///   static variable_list backward(AutogradContext *ctx, variable_list
///   grad_output) {
///      // Use data saved in forward
///      auto n = ctx->saved_data["n"].toInt();
///      return {grad_output[0]*n};
///   }
/// };
/// ```
///
/// To use `MyFunction`:
/// ```
/// Variable x;
/// auto y = MyFunction::apply(6, x);
/// // Example backward call
/// y[0].sum().backward();
/// ```
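///
/// A further sketch (the class and argument names here are illustrative, not
/// part of the API) of a `forward` that takes a `std::optional<at::Tensor>`:
/// the optional argument participates in the graph only when it holds a
/// value, and its gradient slot must be left undefined when it does not:
/// ```
/// class MulMaybeAdd : public Function<MulMaybeAdd> {
///   public:
///   static Variable forward(
///       AutogradContext* ctx,
///       Variable x,
///       std::optional<at::Tensor> bias) {
///     // Remember whether the optional input was present.
///     ctx->saved_data["has_bias"] = bias.has_value();
///     auto out = x * 2;
///     if (bias.has_value()) {
///       out = out + bias.value();
///     }
///     return out;
///   }
///
///   static variable_list backward(AutogradContext* ctx, variable_list grad) {
///     // One entry per forward input; the gradient for an absent optional
///     // input stays undefined.
///     auto grad_bias =
///         ctx->saved_data["has_bias"].toBool() ? grad[0] : Variable();
///     return {grad[0] * 2, grad_bias};
///   }
/// };
/// // e.g. MulMaybeAdd::apply(x, std::nullopt) or MulMaybeAdd::apply(x, bias)
/// ```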
template <class T>
struct TORCH_API Function {
  // We need to use a different template parameter than T here because T will
  // inherit from Function, and when Function<T> is instantiated, T::forward
  // is not declared yet.
  // The enable_if check is to ensure that the user doesn't explicitly provide
  // the parameter X.
  template <typename X = T, typename... Args>
  static auto apply(Args&&... args)
      -> std::enable_if_t<std::is_same_v<X, T>, forward_t<X, Args...>>;

  // This flag is for an experimental feature: compiled autograd. Not all
  // built-in APIs are supported at the moment, e.g. mark_dirty and
  // mark_non_differentiable. Before setting this flag to enable tracing for
  // your custom function <T>, you need to ensure that the backward function is
  // traceable, i.e. any variables accessed in the backward other than the
  // input arguments must be handled in a similar manner to built-ins in
  // CppNode::compiled_args and CppNode::apply_with_saved.
  static constexpr bool is_traceable = false;
};

/// Context to save information during `forward` that can be accessed in
/// `backward` in custom autograd operations (see `torch::autograd::Function`
/// for details).
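///
/// For example, a minimal sketch (`MyMul` is an illustrative name, not part
/// of the API) showing `save_for_backward`/`get_saved_variables` together
/// with `needs_input_grad`:
/// ```
/// class MyMul : public Function<MyMul> {
///   public:
///   static Variable forward(AutogradContext* ctx, Variable a, Variable b) {
///     ctx->save_for_backward({a, b});
///     return a * b;
///   }
///
///   static variable_list backward(AutogradContext* ctx, variable_list grad) {
///     auto saved = ctx->get_saved_variables();
///     auto a = saved[0], b = saved[1];
///     // Only compute a gradient if the corresponding forward input needs one.
///     auto grad_a = ctx->needs_input_grad(0) ? grad[0] * b : Variable();
///     auto grad_b = ctx->needs_input_grad(1) ? grad[0] * a : Variable();
///     return {grad_a, grad_b};
///   }
/// };
/// ```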
struct TORCH_API AutogradContext {
  AutogradContext() = default;
  AutogradContext(const AutogradContext& other) = delete;
  AutogradContext& operator=(const AutogradContext& other) = delete;

  /// Can be used to save non-variable data for `backward`.
  ska::flat_hash_map<std::string, at::IValue> saved_data;

  /// Saves the list of variables for a future call to `backward`. This
  /// should be called at most once from inside of `forward`.
  void save_for_backward(variable_list to_save);
  /// Marks variables in the list as modified in an in-place operation. This
  /// should be called at most once from inside of `forward` and all arguments
  /// should be inputs.
  void mark_dirty(const variable_list& inputs);
  /// Marks outputs in the list as not requiring gradients. This should be
  /// called at most once from inside of `forward` and all arguments should be
  /// outputs.
  void mark_non_differentiable(const variable_list& outputs);
  /// Sets whether undefined output grad tensors should be expanded to tensors
  /// full of zeros before calling the backward function. The default is true.
  void set_materialize_grads(bool value);

  /// Get the list of variables that were saved in `forward` using
  /// `save_for_backward()`. Before returning them to the user, a check is made
  /// to ensure that they were not modified by any in-place operations.
  variable_list get_saved_variables() const;
  const std::unordered_set<at::TensorImpl*>& get_and_bump_dirty() const;
  const std::unordered_set<at::TensorImpl*>& get_non_differentiable() const;

  /// Expose the Node's `task_should_compute_output` method to the C++
  /// custom autograd Function as `needs_input_grad`.
  bool needs_input_grad(size_t output_edge_index) const;
  bool needs_input_grad(std::initializer_list<IndexRange> idxs) const;

 private:
  std::unordered_set<at::TensorImpl*> non_differentiable_;
  std::unordered_set<at::TensorImpl*> dirty_inputs_;
  std::vector<torch::autograd::SavedVariable> saved_variables_;
  variable_list to_save_;
  bool materialize_grads_{true};

  // The CppNode in the autograd graph that owns this AutogradContext. We need
  // a weak_ptr to avoid a reference cycle. Since grad_fn_ owns this
  // AutogradContext, it will always be alive when we want to use it.
  std::weak_ptr<Node> grad_fn_;
  bool has_freed_buffers_{false};

  void save_variables();

  template <class T>
  friend struct CppNode;
};

// CppNode<T> is the Node in the autograd graph that represents the
// user-defined backward function for Function<T>. Calls to CppNode::apply are
// forwarded to T::backward().
template <class T>
struct CppNode : public Node {
  variable_list apply(variable_list&& inputs) override;
  AutogradContext ctx_;
  std::vector<bool> is_variable_input_;
  std::vector<VariableInfo> input_info_;
  std::vector<VariableInfo> output_info_;

  void release_variables() override;

  void set_ctx_grad_fn(const std::shared_ptr<Node>& node);
  void save_variables_to_ctx();

  void compiled_args(CompiledNodeArgs& args) override {
    static_assert(
        std::is_same_v<std::remove_cv_t<decltype(T::is_traceable)>, bool>);
    if (!T::is_traceable) {
      throw std::runtime_error(
          std::string(
              "Attempting to trace a potentially unsafe C++ autograd function: ") +
          name() +
          ". It may be possible to trace it safely, please refer to the instructions in: https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY/.");
    }

    // Although neither of the two methods below (typeid hash_code and name)
    // guarantees uniqueness, it is unlikely that both collide at the same
    // time.
    args.collect(static_cast<uint64_t>(typeid(T).hash_code()));
    args.collect(std::string(typeid(T).name()));

    args.collect(ctx_.saved_data);
    TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty());
    TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty());
    args.collect(
        ctx_.saved_variables_, true); // always unpacked as output in eager
    TORCH_INTERNAL_ASSERT(ctx_.to_save_.empty());
    args.collect(ctx_.materialize_grads_);
    args.collect(ctx_.has_freed_buffers_);
    args.collect(is_variable_input_);
    args.collect(input_info_);
    args.collect(output_info_);
  }

  variable_list apply_with_saved(
      const variable_list& inputs,
      SwapSavedVariables& saved) override {
    saved.before(ctx_.saved_data);
    TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty());
    TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty());
    saved.before(ctx_.saved_variables_);
    TORCH_INTERNAL_ASSERT(ctx_.to_save_.empty());
    saved.before(ctx_.materialize_grads_);
    saved.before(ctx_.has_freed_buffers_);
    saved.before(input_info_);
    saved.before(output_info_);
    auto results = apply(variable_list(inputs));
    saved.after(ctx_.saved_data);
    TORCH_INTERNAL_ASSERT(ctx_.non_differentiable_.empty());
    TORCH_INTERNAL_ASSERT(ctx_.dirty_inputs_.empty());
    saved.after(ctx_.saved_variables_);
    TORCH_INTERNAL_ASSERT(ctx_.to_save_.empty());
    saved.after(ctx_.materialize_grads_);
    saved.after(ctx_.has_freed_buffers_);
    saved.after(input_info_);
    saved.after(output_info_);
    return results;
  }
};

struct ExtractVariables : IterArgs<ExtractVariables> {
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  std::vector<bool>& is_var_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  variable_list& list_;
  ExtractVariables(std::vector<bool>& is_var, variable_list& list)
      : is_var_(is_var), list_(list) {}
  void operator()(const std::optional<at::Tensor>& x) {
    if (x.has_value() && x.value().defined()) {
      is_var_.push_back(true);
      list_.emplace_back(x.value());
    } else {
      is_var_.push_back(false);
    }
  }
  void operator()(const at::Tensor& x) {
    is_var_.push_back(true);
    list_.emplace_back(x);
  }
  void operator()(const at::TensorList& list) {
    for (const at::Tensor& x : list) {
      is_var_.push_back(true);
      list_.emplace_back(x);
    }
  }
  template <typename T>
  void operator()(const T& x) {
    is_var_.push_back(false);
  }
};

template <typename... Args>
inline void extract_vars(
    std::vector<bool>& is_var,
    variable_list& list,
    Args&&... args) {
  ExtractVariables(is_var, list).apply(std::forward<Args>(args)...);
}

template <typename T>
std::enable_if_t<std::is_same_v<T, variable_list>, T> to_output_type(
    std::vector<std::optional<Variable>>& output_list) {
  variable_list result;
  std::transform(
      output_list.begin(),
      output_list.end(),
      std::back_inserter(result),
      [](const std::optional<Variable>& var) { return *var; });
  return result;
}

template <typename T>
std::enable_if_t<std::is_same_v<T, Variable>, T> to_output_type(
    std::vector<std::optional<Variable>>& output_list) {
  return *output_list[0];
}

inline std::vector<std::optional<Variable>> to_optional(Variable& output) {
  return std::vector<std::optional<Variable>>{output};
}

inline std::vector<std::optional<Variable>> to_optional(variable_list& output) {
  std::vector<std::optional<Variable>> result;
  std::transform(
      output.begin(),
      output.end(),
      std::back_inserter(result),
      [](const Variable& var) { return var; });
  return result;
}

template <class T>
template <typename X, typename... Args>
auto Function<T>::apply(Args&&... args)
    -> std::enable_if_t<std::is_same_v<X, T>, forward_t<X, Args...>> {
  const auto& functorch_tls = at::functorch::functorchTLSAccessor();
  if (functorch_tls) {
    // Function support for functorch is handled in Python.
    // Here we are dealing with a (C++) Function, which is not supported.
    // Let's raise an error instead of being silently incorrect.
    functorch_tls->checkSupportsCppAutogradFunction();
  }

  std::shared_ptr<CppNode<T>> node(new CppNode<T>(), deleteNode);
  variable_list input_vars;

  const size_t num_inputs = sizeof...(Args);
  input_vars.reserve(num_inputs);
  node->is_variable_input_.reserve(num_inputs);
  // TODO Add tracing here
  extract_vars(node->is_variable_input_, input_vars, args...);

  bool is_executable =
      GradMode::is_enabled() && any_variable_requires_grad(input_vars);
  auto next_edges =
      (is_executable ? collect_next_edges(input_vars) : edge_list());
  node->set_ctx_grad_fn(node);
  node->set_next_edges(std::move(next_edges));
  node->clear_input_metadata();

  node->input_info_.reserve(input_vars.size());
  for (auto& var : input_vars) {
    node->input_info_.emplace_back(var);
  }

  using forward_return_t = forward_t<X, Args...>;
  forward_return_t outputs;
  {
    AutoGradMode grad_mode(false);
    outputs = T::forward(&node->ctx_, std::forward<Args>(args)...);
  }

  _jvp_fn_t jvp_fn = [](const variable_list& inputs,
                        const variable_list& gI) -> variable_list {
    TORCH_CHECK(
        false,
        "jvp is not implemented for the c++ API of custom Function yet.",
        "Please open a feature request on GitHub if you need this.");
  };

  auto view_as_self_fn = [](const at::Tensor& x) -> at::Tensor {
    return x.view_as(x);
  };

  auto wrapped_outputs = _wrap_outputs(
      input_vars,
      node->ctx_.get_non_differentiable(),
      node->ctx_.get_and_bump_dirty(),
      to_optional(outputs),
      is_executable ? node : nullptr,
      jvp_fn,
      {},
      view_as_self_fn);

  node->output_info_.reserve(wrapped_outputs.size());
  for (auto& output : wrapped_outputs) {
    if (is_executable && output.has_value()) {
      node->output_info_.emplace_back(output.value());
    } else if (is_executable) {
      node->output_info_.emplace_back();
    }
  }

  if (is_executable) {
    node->save_variables_to_ctx();
  }

  // wrapped_outputs will be a variable_list, so convert it to the correct
  // return type. Only Variable and variable_list are accepted as return types.
  return to_output_type<forward_return_t>(wrapped_outputs);
}

// The logic here is the same as PyNode::apply, so changes to it should be done
// in both places.
template <class T>
// NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved)
variable_list CppNode<T>::apply(variable_list&& inputs) {
  at::OptionalDeviceGuard _device_guard;

  auto num_inputs = inputs.size();
  variable_list backward_inputs;
  backward_inputs.reserve(num_inputs);
  for (const auto i : c10::irange(num_inputs)) {
    if (inputs[i].defined() || !ctx_.materialize_grads_) {
      backward_inputs.emplace_back(std::move(inputs[i]));
    } else {
      backward_inputs.emplace_back(output_info_[i].zeros(_device_guard));
    }
  }

  // Acquire the lock here to protect thread safety on the custom C++ Autograd
  // Node. This is needed since we don't know if the user-defined Node will
  // write to the shared data during backward.
  // See Note [Thread Safety on Autograd Node].
  std::lock_guard<std::mutex> lock(mutex_);

  auto outputs = T::backward(&ctx_, backward_inputs);

  const auto num_forward_inputs =
      static_cast<int64_t>(is_variable_input_.size());
  auto num_outputs = static_cast<int64_t>(outputs.size());
  // Returning too many results is ok, but only as long as they're all
  // undefined. Truncate the result vector in that case.
  if (num_outputs > num_forward_inputs) {
    bool all_undef = true;
    for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
      all_undef &= (!outputs[i].defined());
    }
    if (all_undef) {
      outputs.resize(num_forward_inputs);
      num_outputs = num_forward_inputs;
    }
  }

  if (num_outputs != num_forward_inputs) {
    std::string msg("function ");
    msg += name() + " returned an incorrect number of gradients (expected ";
    msg += std::to_string(num_forward_inputs) + ", got ";
    msg += std::to_string(num_outputs) + ")";
    throw std::runtime_error(msg);
  }

  variable_list results;
  results.reserve(num_outputs);
  for (const auto i : c10::irange(num_outputs)) {
    if (!is_variable_input_[i]) {
      if (outputs[i].defined()) {
        std::string msg("function ");
        msg += name() + " returned a defined gradient at position ";
        msg += std::to_string(i + 1) +
            ", but the corresponding forward input was not a Variable";
        throw std::runtime_error(msg);
      }
      continue;
    }
    results.emplace_back(outputs[i]);
  }
  return results;
}

template <class T>
void CppNode<T>::release_variables() {
  // lock to ensure thread safety, see [Thread Safety on Autograd Node]
  std::lock_guard<std::mutex> lock(mutex_);
  ctx_.saved_variables_.clear();
  ctx_.has_freed_buffers_ = true;
}

template <class T>
void CppNode<T>::save_variables_to_ctx() {
  ctx_.save_variables();
}

template <class T>
void CppNode<T>::set_ctx_grad_fn(const std::shared_ptr<Node>& node) {
  ctx_.grad_fn_ = node;
}

} // namespace torch::autograd