#pragma once

#include <c10/util/irange.h>

#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable.h>

#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/jit_decomp_interface.h>
#include <torch/csrc/utils/variadic.h>

#include <cstddef>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#ifdef _MSC_VER
#ifdef Type
#undef Type
#endif
#endif

namespace torch::autograd {
enum class can_mutate_inplace_result {
  success,
  non_default_backward_view,
  view_of_leaf,
  is_leaf,
};

// The requires_grad argument is used to know if the in-place operation needs
// gradients to be set up for it.
// In particular, we can have tensor.requires_grad() != requires_grad when
// writing a Tensor that requires gradients in-place into a Tensor that does
// not require gradients:
//   a = torch.rand(2)
//   b = torch.rand(2, requires_grad=True)
//   a.copy_(b)
inline can_mutate_inplace_result can_mutate_inplace(
    const at::Tensor& tensor,
    bool requires_grad) {
  if (!requires_grad || !GradMode::is_enabled()) {
    return can_mutate_inplace_result::success;
  }
  auto diff_view_meta = impl::get_view_autograd_meta(tensor);
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    if (diff_view_meta->get_creation_meta() != CreationMeta::DEFAULT) {
      return can_mutate_inplace_result::non_default_backward_view;
    }
    if (tensor.requires_grad() && tensor._base().is_leaf()) {
      return can_mutate_inplace_result::view_of_leaf;
    }
  }
  if (tensor.requires_grad() && tensor.is_leaf()) {
    return can_mutate_inplace_result::is_leaf;
  }
  return can_mutate_inplace_result::success;
}
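
// Usage sketch (illustrative only; `t` is a hypothetical defined Tensor):
// a caller can probe whether an in-place write on `t` is legal before
// mutating, and fall back to an out-of-place path otherwise:
//
//   if (can_mutate_inplace(t, t.requires_grad()) !=
//       can_mutate_inplace_result::success) {
//     // e.g. clone first, or let check_inplace() below raise the error
//   }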

inline void check_inplace(const at::Tensor& tensor, bool requires_grad) {
  switch (can_mutate_inplace(tensor, requires_grad)) {
    case can_mutate_inplace_result::success:
      return;
    case can_mutate_inplace_result::non_default_backward_view: {
      return handle_view_on_rebase(impl::get_view_autograd_meta(tensor));
    }
    case can_mutate_inplace_result::view_of_leaf:
      TORCH_CHECK(
          false,
          "a view of a leaf Variable that requires grad is being used in an in-place operation.");
      break;

    case can_mutate_inplace_result::is_leaf:
      TORCH_CHECK(
          false,
          "a leaf Variable that requires grad is being used in an in-place operation.");
      break;
  }
  TORCH_INTERNAL_ASSERT(false);
}

inline void check_inplace(at::ITensorListRef tensors, bool requires_grad) {
  for (const auto& tensor : tensors) {
    check_inplace(tensor, requires_grad);
  }
}
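
// Call pattern sketch, modeled on the generated VariableType kernels
// (`self_` and `other_` are hypothetical arguments of an in-place op):
//
//   auto _any_requires_grad = compute_requires_grad(self_, other_);
//   check_inplace(self_, _any_requires_grad);
//
// compute_requires_grad() comes from torch/csrc/autograd/functions/utils.h,
// included above.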

inline void throw_error_out_requires_grad(const char* name) {
  AT_ERROR(
      name,
      "(): functions with out=... arguments don't support automatic differentiation, "
      "but one of the arguments requires grad.");
}

inline void throw_error_for_complex_autograd(
    const at::Tensor& tensor,
    const char* name) {
  if (tensor.requires_grad()) {
    TORCH_CHECK(
        !tensor.is_complex(),
        name,
        " does not support automatic differentiation for outputs with complex dtype.");
  }
}

inline void throw_error_if_base_and_tensor_are_same(
    const at::Tensor& base,
    const at::Tensor& tensor) {
  TORCH_CHECK(
      base.unsafeGetTensorImpl() != tensor.unsafeGetTensorImpl(),
      "View operation returned a tensor that is the same as the input base tensor.  This "
      "is no longer allowed; you must explicitly create a new tensor (e.g., using .detach()). "
      "As a user, you could have made a mistake implementing __torch_dispatch__ or a Python "
      "operator decomposition or meta registration; if that's not the case, please "
      "report a bug to PyTorch or the backend you are using.");
}

inline void throw_error_for_complex_autograd(
    at::ITensorListRef tensorlist,
    const char* name) {
  for (const auto& tensor : tensorlist) {
    throw_error_for_complex_autograd(tensor, name);
  }
}

// TODO: Blegh, bare references

inline void rebase_history(const Variable& var, std::shared_ptr<Node> grad_fn) {
  if (grad_fn && var.defined()) {
    grad_fn->add_input_metadata(var);
    impl::rebase_history(var, {std::move(grad_fn), 0});
  }
}

inline void rebase_history(
    const std::vector<Variable>& vars,
    const std::shared_ptr<Node>& grad_fn) {
  if (grad_fn) {
    for (auto& var : vars) {
      if (var.defined()) {
        auto output_nr = grad_fn->add_input_metadata(var);
        impl::rebase_history(var, {grad_fn, output_nr});
      } else {
        grad_fn->add_input_metadata(Node::undefined_input());
      }
    }
  }
}
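
// Sketch of the expected call pattern after an in-place op (hypothetical
// `self` output and freshly created `grad_fn` node): the mutated tensor is
// re-attached to the new grad_fn instead of getting a new history entry:
//
//   if (grad_fn) {
//     rebase_history(flatten_tensor_args(self), grad_fn);
//   }
//
// (flatten_tensor_args is defined below.)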
156 
increment_version(const at::Tensor & t)157 inline void increment_version(const at::Tensor& t) {
158   impl::bump_version(t);
159 }
160 
161 struct Flatten : IterArgs<Flatten> {
FlattenFlatten162   Flatten(variable_list& out) : out(out) {}
163   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
164   variable_list& out;
operatorFlatten165   void operator()(const at::Tensor& x) {
166     out.emplace_back(x);
167   }
operatorFlatten168   void operator()(const std::optional<at::Tensor>& x) {
169     if (x.has_value())
170       out.emplace_back(x.value());
171   }
operatorFlatten172   void operator()(at::ArrayRef<at::Tensor> xs) {
173     out.insert(out.end(), xs.begin(), xs.end());
174   }
175 };

template <typename... Args>
inline variable_list flatten_tensor_args(Args&&... args) {
  variable_list out;
  out.reserve(count_tensors(std::forward<Args>(args)...));
  Flatten(out).apply(std::forward<Args>(args)...);
  return out; // RVO
}
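
// Example (illustrative): flattening heterogeneous arguments into a single
// variable_list, assuming Tensors `a`, `b` and a TensorList `ts`:
//
//   variable_list all = flatten_tensor_args(a, b, ts);
//
// Optional and list arguments are handled by the Flatten overloads above;
// an empty std::optional<Tensor> simply contributes nothing.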

// See NOTE [ Autograd View Variables ] for details.
inline at::Tensor as_view(
    const at::Tensor& base,
    const at::Tensor& tensor,
    bool is_bw_differentiable,
    bool is_fw_differentiable,
    std::unique_ptr<ViewFunc> view_func = nullptr,
    std::function<at::Tensor(const at::Tensor&)> rev_view_func = nullptr,
    CreationMeta creation_meta = CreationMeta::DEFAULT,
    bool allow_tensor_metadata_change = true) {
  // Note [View of inference tensor]
  // For inference tensors this code can only be hit outside InferenceMode,
  // since ADInplaceOrView is in the default_included_set.
  // If Inplace and View were separate dispatch keys, we could just put Inplace
  // in the default_included_set, so that view ops on inference tensors
  // wouldn't have to go through as_view even outside InferenceMode.
  if (base.is_inference())
    return tensor;

  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base);

  // To speed up the most common case, we specially handle the situation where
  // the forward and backward view infos are the same, so that a single shared
  // ViewInfo can be used for both of them.
  if ((!diff_view_meta || diff_view_meta->shared_view_info()) &&
      is_bw_differentiable && is_fw_differentiable) {
    throw_error_if_base_and_tensor_are_same(base, tensor);
    if (diff_view_meta) {
      creation_meta = propagate_creation_meta(
          diff_view_meta->get_creation_meta(), creation_meta);
      return make_variable_differentiable_view(
          tensor,
          diff_view_meta->get_backward_view().chain(
              base, tensor, std::move(view_func), std::move(rev_view_func)),
          std::nullopt,
          /*shared_view_info*/ true,
          creation_meta,
          allow_tensor_metadata_change);
    } else {
      return make_variable_differentiable_view(
          tensor,
          ViewInfo(base, std::move(view_func), std::move(rev_view_func)),
          std::nullopt,
          /*shared_view_info*/ true,
          creation_meta,
          allow_tensor_metadata_change);
    }
  }

  // If they cannot be shared, create the required view infos
  std::optional<ViewInfo> new_bw_info;
  std::optional<ViewInfo> new_fw_info;

  if (is_bw_differentiable) {
    auto bw_view_func = view_func ? view_func->clone_and_set() : nullptr;
    if (diff_view_meta && diff_view_meta->has_bw_view()) {
      const auto& base_bw_info = diff_view_meta->get_backward_view();
      new_bw_info = base_bw_info.chain(
          base, tensor, std::move(bw_view_func), rev_view_func);
    } else {
      new_bw_info = ViewInfo(base, std::move(bw_view_func), rev_view_func);
    }
  } else {
    TORCH_CHECK(
        creation_meta == CreationMeta::DEFAULT,
        "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT");
  }

  if (is_fw_differentiable) {
    // Check if base is a forward differentiable view
    if (diff_view_meta && diff_view_meta->has_fw_view()) {
      const auto& base_fw_info = diff_view_meta->get_forward_view();
      new_fw_info = base_fw_info.chain(
          base, tensor, std::move(view_func), std::move(rev_view_func));
    } else {
      new_fw_info =
          ViewInfo(base, std::move(view_func), std::move(rev_view_func));
    }
  }

  if (is_fw_differentiable || is_bw_differentiable) {
    if (diff_view_meta && diff_view_meta->has_bw_view()) {
      creation_meta = propagate_creation_meta(
          diff_view_meta->get_creation_meta(), creation_meta);
    }
    throw_error_if_base_and_tensor_are_same(base, tensor);
    return make_variable_differentiable_view(
        tensor,
        std::move(new_bw_info),
        std::move(new_fw_info),
        /*shared_view_info*/ false,
        creation_meta,
        allow_tensor_metadata_change);
  } else {
    return make_variable_non_differentiable_view(
        base, tensor, allow_tensor_metadata_change);
  }
}
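
// Usage sketch, loosely following the generated ADInplaceOrView kernels
// (hypothetical `base` input and raw `output` of a view op such as slice):
//
//   auto result = as_view(
//       /*base=*/base,
//       /*tensor=*/output,
//       /*is_bw_differentiable=*/true,
//       /*is_fw_differentiable=*/true,
//       /*view_func=*/nullptr,
//       /*rev_view_func=*/nullptr,
//       /*creation_meta=*/CreationMeta::DEFAULT);
//
// Passing nullptr view funcs lets autograd fall back to its default view
// reconstruction when replaying the view chain.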

inline void check_no_requires_grad(
    const at::Tensor& tensor,
    const char* name,
    const char* fn_name = "",
    bool check_grad_mode = true) {
  TORCH_CHECK(
      !(tensor.defined() && tensor.requires_grad()) ||
          !(check_grad_mode && GradMode::is_enabled()),
      "The function '",
      fn_name,
      "' is not differentiable with respect to argument '",
      name,
      "'. This input cannot have requires_grad True.");
}

inline void check_no_requires_grad(
    const std::optional<at::Tensor>& tensor,
    const char* name,
    const char* fn_name = "") {
  if (tensor.has_value()) {
    check_no_requires_grad(*tensor, name, fn_name);
  }
}

inline void check_no_requires_grad(
    at::ITensorListRef tensors,
    const char* name,
    const char* fn_name = "") {
  // GradMode check is expensive, so check it only once for TensorLists
  if (!GradMode::is_enabled()) {
    return;
  }
  for (auto& tensor : tensors) {
    check_no_requires_grad(tensor, name, fn_name, /*check_grad_mode*/ false);
  }
}

inline void check_no_requires_grad(
    const c10::List<std::optional<at::Tensor>>& tensors,
    const char* name,
    const char* fn_name = "") {
  // GradMode check is expensive, so check it only once for TensorLists
  if (!GradMode::is_enabled()) {
    return;
  }
  for (std::optional<at::Tensor> tensor : tensors) {
    if (tensor.has_value()) {
      check_no_requires_grad(*tensor, name, fn_name, /*check_grad_mode*/ false);
    }
  }
}
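
// Usage sketch (hypothetical function `foo` that is not differentiable with
// respect to its `weight` argument):
//
//   check_no_requires_grad(weight, "weight", "foo");
//
// With grad mode enabled, this raises if `weight` is defined and
// requires_grad() is true.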

// It is assumed that saved tensor lists are never in-place outputs.
inline std::vector<SavedVariable> make_saved_variable_list(
    at::ITensorListRef tensors,
    const bool is_output = false) {
  return fmap(tensors, [&is_output](const at::Tensor& tensor) -> SavedVariable {
    return SavedVariable{tensor, is_output /* is output */};
  });
}

// It is assumed that saved tensor lists are never in-place outputs.
inline std::vector<SavedVariable> make_saved_variable_list(
    const c10::List<std::optional<at::Tensor>>& tensors,
    const bool is_output = false) {
  return fmap(
      tensors,
      [&is_output](const std::optional<at::Tensor>& tensor) -> SavedVariable {
        if (tensor.has_value()) {
          return SavedVariable{*tensor, is_output /* is output */};
        } else {
          return SavedVariable{at::Tensor(), is_output /* is output */};
        }
      });
}
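
// Sketch of how a grad_fn saves a TensorList input for backward (hypothetical
// `grad_fn->tensors_` member, in the style of the generated autograd nodes):
//
//   grad_fn->tensors_ = make_saved_variable_list(tensors);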

inline std::vector<std::vector<int64_t>> to_args_sizes(
    at::ITensorListRef tensors) {
  std::vector<std::vector<int64_t>> args_sizes(tensors.size());
  size_t i = 0;
  for (const auto& t : tensors) {
    args_sizes[i++] = t.sizes().vec();
  }
  return args_sizes;
}

inline std::vector<std::vector<c10::SymInt>> to_args_sizes_symint(
    at::ITensorListRef tensors) {
  std::vector<std::vector<c10::SymInt>> args_sizes(tensors.size());
  size_t i = 0;
  for (const auto& t : tensors) {
    args_sizes[i++] = t.sym_sizes().vec();
  }
  return args_sizes;
}

inline std::vector<c10::ScalarType> to_args_scalartypes(
    at::ITensorListRef tensors) {
  std::vector<c10::ScalarType> args_scalartypes(tensors.size());
  size_t i = 0;
  for (const auto& t : tensors) {
    args_scalartypes[i++] = t.scalar_type();
  }
  return args_scalartypes;
}
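
// Sketch: recording per-argument metadata before the inputs may be freed or
// overwritten (hypothetical grad_fn members, in the style of the generated
// node for TensorList ops such as cat):
//
//   grad_fn->tensors_args_sizes_symint = to_args_sizes_symint(tensors);
//   grad_fn->tensors_args_scalartypes = to_args_scalartypes(tensors);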

namespace impl {

namespace {

// If run_jit_decomposition were not a member function, we would be able
// to pass it as a template parameter to c10::BoxedKernel::makeFromFunction.
// However, member functions cannot be passed this way; instead, we wrap the
// call in this functor so it can be passed to c10::BoxedKernel::makeFromFunctor.
class WrapperFunctor final : public c10::OperatorKernel {
 public:
  WrapperFunctor(JitDecompInterface* impl) : impl_(impl) {}

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet ks,
      torch::jit::Stack* stack) {
    impl_->run_jit_decomposition(op, stack);
  }
  JitDecompInterface* impl_;
};

} // namespace

template <class Return, class... Args>
Return run_jit_decomposition_with_args_for_jvp(
    c10::string_view name,
    const c10::OperatorHandle& opHandle,
    c10::DispatchKeySet dispatchKeySet,
    Args&&... args) {
  // see NOTE: [Jit Decomposition Interface]
  JitDecompInterface* impl = getJitDecompImpl();

  TORCH_CHECK_NOT_IMPLEMENTED(
      impl && impl->has_jit_decomposition(opHandle.schema()),
      "Trying to use forward AD with ",
      name,
      " that does not support it because it has not been implemented yet.\nPlease file an issue "
      "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
      "so that we can prioritize its implementation or submit a PR adding the implementation to "
      "derivatives.yaml");

  return c10::KernelFunction::makeFromBoxedKernel(
             c10::BoxedKernel::makeFromFunctor(
                 std::make_unique<WrapperFunctor>(impl)))
      .call<Return, Args...>(
          opHandle, dispatchKeySet, std::forward<Args>(args)...);
}
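
// Call sketch, modeled on the generated forward-AD fallthrough (hypothetical
// op `foo` with handle `opHandle`, key set `ks`, and Tensor argument `self`):
//
//   return run_jit_decomposition_with_args_for_jvp<at::Tensor>(
//       "foo", opHandle, ks, self);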

} // namespace impl

} // namespace torch::autograd