#pragma once

#include <c10/util/irange.h>

#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable.h>

#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/jit_decomp_interface.h>
#include <torch/csrc/utils/variadic.h>

#include <cstddef>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#ifdef _MSC_VER
#ifdef Type
#undef Type
#endif
#endif

namespace torch::autograd {
enum class can_mutate_inplace_result {
  success,
  non_default_backward_view,
  view_of_leaf,
  is_leaf,
};

// The requires_grad argument indicates whether the in-place operation needs
// gradients to be set up for it.
// In particular, tensor.requires_grad() may differ from requires_grad when
// writing a Tensor that requires gradients in place into a Tensor that does
// not require gradients, e.g.:
//   a = torch.rand(2)
//   b = torch.rand(2, requires_grad=True)
//   a.copy_(b)
inline can_mutate_inplace_result can_mutate_inplace(
    const at::Tensor& tensor,
    bool requires_grad) {
  if (!requires_grad || !GradMode::is_enabled()) {
    return can_mutate_inplace_result::success;
  }
  auto diff_view_meta = impl::get_view_autograd_meta(tensor);
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    if (diff_view_meta->get_creation_meta() != CreationMeta::DEFAULT) {
      return can_mutate_inplace_result::non_default_backward_view;
    }
    if (tensor.requires_grad() && tensor._base().is_leaf()) {
      return can_mutate_inplace_result::view_of_leaf;
    }
  }
  if (tensor.requires_grad() && tensor.is_leaf()) {
    return can_mutate_inplace_result::is_leaf;
  }
  return can_mutate_inplace_result::success;
}

inline void check_inplace(const at::Tensor& tensor, bool requires_grad) {
  switch (can_mutate_inplace(tensor, requires_grad)) {
    case can_mutate_inplace_result::success:
      return;
    case can_mutate_inplace_result::non_default_backward_view: {
      return handle_view_on_rebase(impl::get_view_autograd_meta(tensor));
    }
    case can_mutate_inplace_result::view_of_leaf:
      TORCH_CHECK(
          false,
          "a view of a leaf Variable that requires grad is being used in an in-place operation.");
      break;

    case can_mutate_inplace_result::is_leaf:
      TORCH_CHECK(
          false,
          "a leaf Variable that requires grad is being used in an in-place operation.");
      break;
  }
  TORCH_INTERNAL_ASSERT(false);
}

inline void check_inplace(at::ITensorListRef tensors, bool requires_grad) {
  for (const auto& tensor : tensors) {
    check_inplace(tensor, requires_grad);
  }
}

inline void throw_error_out_requires_grad(const char* name) {
  AT_ERROR(
      name,
      "(): functions with out=... arguments don't support automatic differentiation, "
      "but one of the arguments requires grad.");
}

inline void throw_error_for_complex_autograd(
    const at::Tensor& tensor,
    const char* name) {
  if (tensor.requires_grad()) {
    TORCH_CHECK(
        !tensor.is_complex(),
        name,
        " does not support automatic differentiation for outputs with complex dtype.");
  }
}

inline void throw_error_if_base_and_tensor_are_same(
    const at::Tensor& base,
    const at::Tensor& tensor) {
  TORCH_CHECK(
      base.unsafeGetTensorImpl() != tensor.unsafeGetTensorImpl(),
      "View operation returned a tensor that is the same as the input base tensor. This "
      "is no longer allowed; you must explicitly create a new tensor (e.g., using .detach()). "
      "As a user, you could have made a mistake implementing __torch_dispatch__ or a Python "
      "operator decomposition or meta registration; if that's not the case, please "
      "report a bug to PyTorch or the backend you are using.");
}

inline void throw_error_for_complex_autograd(
    at::ITensorListRef tensorlist,
    const char* name) {
  for (const auto& tensor : tensorlist) {
    throw_error_for_complex_autograd(tensor, name);
  }
}

// TODO: Blegh, bare references

inline void rebase_history(const Variable& var, std::shared_ptr<Node> grad_fn) {
  if (grad_fn && var.defined()) {
    grad_fn->add_input_metadata(var);
    impl::rebase_history(var, {std::move(grad_fn), 0});
  }
}

inline void rebase_history(
    const std::vector<Variable>& vars,
    const std::shared_ptr<Node>& grad_fn) {
  if (grad_fn) {
    for (auto& var : vars) {
      if (var.defined()) {
        auto output_nr = grad_fn->add_input_metadata(var);
        impl::rebase_history(var, {grad_fn, output_nr});
      } else {
        grad_fn->add_input_metadata(Node::undefined_input());
      }
    }
  }
}

inline void increment_version(const at::Tensor& t) {
  impl::bump_version(t);
}
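
// The helpers above (check_inplace, increment_version, rebase_history) are
// typically used together by an in-place op's autograd kernel. A rough
// illustrative sketch, not actual generated code; `self`, `src`, and `grad_fn`
// are hypothetical names:
//
//   check_inplace(self, compute_requires_grad(self, src));
//   // ... build grad_fn and collect its next edges ...
//   increment_version(self);
//   rebase_history(self, grad_fn);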

struct Flatten : IterArgs<Flatten> {
  Flatten(variable_list& out) : out(out) {}
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  variable_list& out;
  void operator()(const at::Tensor& x) {
    out.emplace_back(x);
  }
  void operator()(const std::optional<at::Tensor>& x) {
    if (x.has_value())
      out.emplace_back(x.value());
  }
  void operator()(at::ArrayRef<at::Tensor> xs) {
    out.insert(out.end(), xs.begin(), xs.end());
  }
};

template <typename... Args>
inline variable_list flatten_tensor_args(Args&&... args) {
  variable_list out;
  out.reserve(count_tensors(std::forward<Args>(args)...));
  Flatten(out).apply(std::forward<Args>(args)...);
  return out; // RVO
}
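
// Example (illustrative sketch): gathering every Tensor argument of an op into
// a single variable_list, e.g. to feed compute_requires_grad or
// collect_next_edges. `self`, `weight`, and `bias` are hypothetical arguments;
// std::optional<Tensor> arguments without a value are skipped by the Flatten
// visitor above.
//
//   variable_list all_inputs = flatten_tensor_args(self, weight, bias);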

// See NOTE [ Autograd View Variables ] for details.
inline at::Tensor as_view(
    const at::Tensor& base,
    const at::Tensor& tensor,
    bool is_bw_differentiable,
    bool is_fw_differentiable,
    std::unique_ptr<ViewFunc> view_func = nullptr,
    std::function<at::Tensor(const at::Tensor&)> rev_view_func = nullptr,
    CreationMeta creation_meta = CreationMeta::DEFAULT,
    bool allow_tensor_metadata_change = true) {
  // Note [View of inference tensor]
  // For an inference tensor, this code can only be hit outside InferenceMode,
  // since ADInplaceOrView is in the default_included_set.
  // If Inplace and View were separate dispatch keys, we could just put Inplace
  // in the default_included_set, so that view ops on inference tensors
  // wouldn't have to go through as_view even outside InferenceMode.
  if (base.is_inference())
    return tensor;

  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(base);

  // To speed up the most common case, we specially handle when both the forward
  // and backward view infos are the same, and so a single shared ViewInfo can
  // be used for both of them.
  if ((!diff_view_meta || diff_view_meta->shared_view_info()) &&
      is_bw_differentiable && is_fw_differentiable) {
    throw_error_if_base_and_tensor_are_same(base, tensor);
    if (diff_view_meta) {
      creation_meta = propagate_creation_meta(
          diff_view_meta->get_creation_meta(), creation_meta);
      return make_variable_differentiable_view(
          tensor,
          diff_view_meta->get_backward_view().chain(
              base, tensor, std::move(view_func), std::move(rev_view_func)),
          std::nullopt,
          /*shared_view_info*/ true,
          creation_meta,
          allow_tensor_metadata_change);
    } else {
      return make_variable_differentiable_view(
          tensor,
          ViewInfo(base, std::move(view_func), std::move(rev_view_func)),
          std::nullopt,
          /*shared_view_info*/ true,
          creation_meta,
          allow_tensor_metadata_change);
    }
  }

  // If they cannot be shared, create the required view infos
  std::optional<ViewInfo> new_bw_info;
  std::optional<ViewInfo> new_fw_info;

  if (is_bw_differentiable) {
    auto bw_view_func = view_func ? view_func->clone_and_set() : nullptr;
    if (diff_view_meta && diff_view_meta->has_bw_view()) {
      const auto& base_bw_info = diff_view_meta->get_backward_view();
      new_bw_info = base_bw_info.chain(
          base, tensor, std::move(bw_view_func), rev_view_func);
    } else {
      new_bw_info = ViewInfo(base, std::move(bw_view_func), rev_view_func);
    }
  } else {
    TORCH_CHECK(
        creation_meta == CreationMeta::DEFAULT,
        "Non-backward differentiable views must have creation_meta=CreationMeta::DEFAULT");
  }

  if (is_fw_differentiable) {
    // Check if base is a forward differentiable view
    if (diff_view_meta && diff_view_meta->has_fw_view()) {
      const auto& base_fw_info = diff_view_meta->get_forward_view();
      new_fw_info = base_fw_info.chain(
          base, tensor, std::move(view_func), std::move(rev_view_func));
    } else {
      new_fw_info =
          ViewInfo(base, std::move(view_func), std::move(rev_view_func));
    }
  }

  if (is_fw_differentiable || is_bw_differentiable) {
    if (diff_view_meta && diff_view_meta->has_bw_view()) {
      creation_meta = propagate_creation_meta(
          diff_view_meta->get_creation_meta(), creation_meta);
    }
    throw_error_if_base_and_tensor_are_same(base, tensor);
    return make_variable_differentiable_view(
        tensor,
        std::move(new_bw_info),
        std::move(new_fw_info),
        /*shared_view_info*/ false,
        creation_meta,
        allow_tensor_metadata_change);
  } else {
    return make_variable_non_differentiable_view(
        base, tensor, allow_tensor_metadata_change);
  }
}
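
// Illustrative use of as_view from a view op's ADInplaceOrView kernel (a rough
// sketch under assumed names, not the actual generated code):
//
//   auto out = as_view(
//       /*base=*/self,
//       /*tensor=*/at::alias(self),
//       /*is_bw_differentiable=*/true,
//       /*is_fw_differentiable=*/true,
//       /*view_func=*/nullptr,
//       /*rev_view_func=*/nullptr,
//       /*creation_meta=*/GradMode::is_enabled() ? CreationMeta::DEFAULT
//                                                : CreationMeta::NO_GRAD_MODE);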

inline void check_no_requires_grad(
    const at::Tensor& tensor,
    const char* name,
    const char* fn_name = "",
    bool check_grad_mode = true) {
  TORCH_CHECK(
      !(tensor.defined() && tensor.requires_grad()) ||
          !(check_grad_mode && GradMode::is_enabled()),
      "The function '",
      fn_name,
      "' is not differentiable with respect to argument '",
      name,
      "'. This input cannot have requires_grad True.");
}

inline void check_no_requires_grad(
    const std::optional<at::Tensor>& tensor,
    const char* name,
    const char* fn_name = "") {
  if (tensor.has_value()) {
    check_no_requires_grad(*tensor, name, fn_name);
  }
}

inline void check_no_requires_grad(
    at::ITensorListRef tensors,
    const char* name,
    const char* fn_name = "") {
  // GradMode check is expensive, so check it only once for TensorLists
  if (!GradMode::is_enabled()) {
    return;
  }
  for (auto& tensor : tensors) {
    check_no_requires_grad(tensor, name, fn_name, /*check_grad_mode*/ false);
  }
}

inline void check_no_requires_grad(
    const c10::List<std::optional<at::Tensor>>& tensors,
    const char* name,
    const char* fn_name = "") {
  // GradMode check is expensive, so check it only once for TensorLists
  if (!GradMode::is_enabled()) {
    return;
  }
  for (std::optional<at::Tensor> tensor : tensors) {
    if (tensor.has_value()) {
      check_no_requires_grad(*tensor, name, fn_name, /*check_grad_mode*/ false);
    }
  }
}
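
// Example (sketch): a kernel that is not differentiable with respect to a
// particular argument can guard it like this; the op and argument names are
// hypothetical:
//
//   check_no_requires_grad(weight, "weight", "my_op");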

// Assumed that saved tensor lists are never inplace outputs
inline std::vector<SavedVariable> make_saved_variable_list(
    at::ITensorListRef tensors,
    const bool is_output = false) {
  return fmap(tensors, [&is_output](const at::Tensor& tensor) -> SavedVariable {
    return SavedVariable{tensor, is_output /* is output */};
  });
}

// Assumed that saved tensor lists are never inplace outputs
inline std::vector<SavedVariable> make_saved_variable_list(
    const c10::List<std::optional<at::Tensor>>& tensors,
    const bool is_output = false) {
  return fmap(
      tensors,
      [&is_output](const std::optional<at::Tensor>& tensor) -> SavedVariable {
        if (tensor.has_value()) {
          return SavedVariable{*tensor, is_output /* is output */};
        } else {
          return SavedVariable{at::Tensor(), is_output /* is output */};
        }
      });
}
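
// Example (sketch, with a hypothetical Node member `tensors_`): saving a
// TensorList input on a grad_fn for use in its backward:
//
//   grad_fn->tensors_ = make_saved_variable_list(tensors);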

inline std::vector<std::vector<int64_t>> to_args_sizes(
    at::ITensorListRef tensors) {
  std::vector<std::vector<int64_t>> args_sizes(tensors.size());
  size_t i = 0;
  for (const auto& t : tensors) {
    args_sizes[i++] = t.sizes().vec();
  }
  return args_sizes;
}

inline std::vector<std::vector<c10::SymInt>> to_args_sizes_symint(
    at::ITensorListRef tensors) {
  std::vector<std::vector<c10::SymInt>> args_sizes(tensors.size());
  size_t i = 0;
  for (const auto& t : tensors) {
    args_sizes[i++] = t.sym_sizes().vec();
  }
  return args_sizes;
}

inline std::vector<c10::ScalarType> to_args_scalartypes(
    at::ITensorListRef tensors) {
  std::vector<c10::ScalarType> args_scalartypes(tensors.size());
  size_t i = 0;
  for (const auto& t : tensors) {
    args_scalartypes[i++] = t.scalar_type();
  }
  return args_scalartypes;
}

namespace impl {

namespace {

// If run_jit_decomposition were not a member function, we would be able
// to pass this as a template parameter to c10::BoxedKernel::makeFromFunction.
// However, member functions cannot be passed this way - instead we wrap our
// call in this functor so it can be passed to c10::BoxedKernel::makeFromFunctor.
class WrapperFunctor final : public c10::OperatorKernel {
 public:
  WrapperFunctor(JitDecompInterface* impl) : impl_(impl) {}

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet ks,
      torch::jit::Stack* stack) {
    impl_->run_jit_decomposition(op, stack);
  }
  JitDecompInterface* impl_;
};

} // namespace

template <class Return, class... Args>
Return run_jit_decomposition_with_args_for_jvp(
    c10::string_view name,
    const c10::OperatorHandle& opHandle,
    c10::DispatchKeySet dispatchKeySet,
    Args&&... args) {
  // see NOTE: [Jit Decomposition Interface]
  JitDecompInterface* impl = getJitDecompImpl();

  TORCH_CHECK_NOT_IMPLEMENTED(
      impl && impl->has_jit_decomposition(opHandle.schema()),
      "Trying to use forward AD with ",
      name,
      " that does not support it because it has not been implemented yet.\nPlease file an issue "
      "to PyTorch at https://github.com/pytorch/pytorch/issues/new?template=feature-request.yml "
      "so that we can prioritize its implementation or submit a PR adding the implementation to "
      "derivatives.yaml");

  return c10::KernelFunction::makeFromBoxedKernel(
             c10::BoxedKernel::makeFromFunctor(
                 std::make_unique<WrapperFunctor>(impl)))
      .call<Return, Args...>(
          opHandle, dispatchKeySet, std::forward<Args>(args)...);
}
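
// Example (sketch): a kernel without a forward-mode AD formula can fall back
// to the JIT decomposition when any input carries a forward gradient. The
// names `any_has_forward_grad`, `opHandle`, `ks`, `self`, and `other` are
// assumptions, not actual generated code:
//
//   if (any_has_forward_grad) {
//     return run_jit_decomposition_with_args_for_jvp<at::Tensor>(
//         "my_op", opHandle, ks, self, other);
//   }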

} // namespace impl

} // namespace torch::autograd