xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/op_registration/adaption.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Tensor.h>
4 #include <ATen/TensorUtils.h>
5 #include <ATen/core/List.h>
6 #include <c10/core/TensorOptions.h>
7 
8 /*
9  * [Note: hacky wrapper removal for optional tensor]
10  *
11  * The kernel implementation takes an optional tensor marked in the schema as
12  * Tensor? but the C++ function takes Tensor instead of the std::optional<Tensor>
13  * expected by the dispatcher.
14  *
15  * To remove the hacky wrapper, the C++ function is changed to take
16  * std::optional<Tensor> and unwrap the Tensor value at the beginning of
17  * the function, e.g.:
18  *   > c10::MaybeOwned<Tensor> weight_maybe_owned =
19  *   >     at::borrow_from_optional_tensor(weight_opt);
20  *   > const Tensor& weight = *weight_maybe_owned;
21  *
22  * We may want to make the kernel handle optional directly without
23  * going through the creation of a default-constructed Tensor in
24  * at::borrow_from_optional_tensor.
25  */
26 
27 /*
28  * [Note: hacky wrapper removal for TensorOptions]
29  *
30  * The kernel implementation takes a TensorOptions argument but the dispatcher
31  * expects separate arguments for dtype, layout, device, pin_memory.
32  *
33  * To remove the hacky wrapper, the kernel implementation is changed to take
34  * the 4 arguments (dtype, layout, device, pin_memory), and assemble the
35  * TensorOptions value at the beginning of the function, e.g.:
36  *   > TensorOptions options = TensorOptions().dtype(dtype).layout(layout)
37  *   >    .device(device).pinned_memory(pin_memory);
38  *
39  * We may want to make the kernel handle these parameters directly without going
40  * through the creation of a TensorOptions value.
41  */
42 
43 namespace c10 {
44 namespace impl {
45 
46 TORCH_API void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName);
47 
check_and_update_common_device(std::optional<Device> & common_device,const at::Tensor & tensor,at::CheckedFrom methodName,at::CheckedFrom argName)48 inline void check_and_update_common_device(std::optional<Device>& common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
49   // TODO: Remove this once the following issue is addressed:
50   // https://github.com/pytorch/pytorch/issues/57380
51   if (!tensor.defined()) {
52     return;
53   }
54 
55   if (!common_device.has_value()) {
56     common_device = tensor.device();
57     return;
58   }
59 
60   if (C10_UNLIKELY(common_device != tensor.device())) {
61     common_device_check_failure(*common_device, tensor, methodName, argName);
62   }
63 }
64 
check_and_update_common_device(std::optional<Device> & common_device,const std::optional<at::Tensor> & tensor,at::CheckedFrom methodName,at::CheckedFrom argName)65 inline void check_and_update_common_device(std::optional<Device>& common_device, const std::optional<at::Tensor>& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
66   if (tensor.has_value()) {
67     check_and_update_common_device(common_device, tensor.value(), methodName, argName);
68   }
69 }
70 
check_and_update_common_device(std::optional<Device> & common_device,at::ITensorListRef tensors,at::CheckedFrom methodName,at::CheckedFrom argName)71 inline void check_and_update_common_device(std::optional<Device>& common_device, at::ITensorListRef tensors, at::CheckedFrom methodName, at::CheckedFrom argName) {
72   for (const auto& tensor : tensors) {
73     check_and_update_common_device(common_device, tensor, methodName, argName);
74   }
75 }
76 
check_and_update_common_device(std::optional<Device> & common_device,const List<std::optional<at::Tensor>> & tensors,at::CheckedFrom methodName,at::CheckedFrom argName)77 inline void check_and_update_common_device(std::optional<Device>& common_device, const List<std::optional<at::Tensor>>& tensors, at::CheckedFrom methodName, at::CheckedFrom argName) {
78   for (const auto& tensor : tensors) {
79     check_and_update_common_device(common_device, tensor, methodName, argName);
80   }
81 }
82 } // namespace impl
83 } // namespace c10
84