1 #pragma once
2
3 #include <ATen/Tensor.h>
4 #include <ATen/TensorUtils.h>
5 #include <ATen/core/List.h>
6 #include <c10/core/TensorOptions.h>
7
/*
 * [Note: hacky wrapper removal for optional tensor]
 *
 * The kernel implementation takes an optional tensor marked in the schema as
 * Tensor? but the C++ function takes Tensor instead of the std::optional<Tensor>
 * expected by the dispatcher.
 *
 * To remove the hacky wrapper, the C++ function is changed to take
 * std::optional<Tensor> and unwrap the Tensor value at the beginning of
 * the function, e.g.:
 *   > c10::MaybeOwned<Tensor> weight_maybe_owned =
 *   >     at::borrow_from_optional_tensor(weight_opt);
 *   > const Tensor& weight = *weight_maybe_owned;
 *
 * We may want to make the kernel handle optional directly without
 * going through the creation of a default-constructed Tensor in
 * at::borrow_from_optional_tensor.
 */
26
/*
 * [Note: hacky wrapper removal for TensorOptions]
 *
 * The kernel implementation takes a TensorOptions argument but the dispatcher
 * expects separate arguments for dtype, layout, device, pin_memory.
 *
 * To remove the hacky wrapper, the kernel implementation is changed to take
 * the 4 arguments (dtype, layout, device, pin_memory), and assemble the
 * TensorOptions value at the beginning of the function, e.g.:
 *   > TensorOptions options = TensorOptions().dtype(dtype).layout(layout)
 *   >     .device(device).pinned_memory(pin_memory);
 *
 * We may want to make the kernel handle these parameters directly without
 * going through the creation of a TensorOptions value.
 */
42
43 namespace c10 {
44 namespace impl {
45
46 TORCH_API void common_device_check_failure(Device common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName);
47
check_and_update_common_device(std::optional<Device> & common_device,const at::Tensor & tensor,at::CheckedFrom methodName,at::CheckedFrom argName)48 inline void check_and_update_common_device(std::optional<Device>& common_device, const at::Tensor& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
49 // TODO: Remove this once the following issue is addressed:
50 // https://github.com/pytorch/pytorch/issues/57380
51 if (!tensor.defined()) {
52 return;
53 }
54
55 if (!common_device.has_value()) {
56 common_device = tensor.device();
57 return;
58 }
59
60 if (C10_UNLIKELY(common_device != tensor.device())) {
61 common_device_check_failure(*common_device, tensor, methodName, argName);
62 }
63 }
64
check_and_update_common_device(std::optional<Device> & common_device,const std::optional<at::Tensor> & tensor,at::CheckedFrom methodName,at::CheckedFrom argName)65 inline void check_and_update_common_device(std::optional<Device>& common_device, const std::optional<at::Tensor>& tensor, at::CheckedFrom methodName, at::CheckedFrom argName) {
66 if (tensor.has_value()) {
67 check_and_update_common_device(common_device, tensor.value(), methodName, argName);
68 }
69 }
70
check_and_update_common_device(std::optional<Device> & common_device,at::ITensorListRef tensors,at::CheckedFrom methodName,at::CheckedFrom argName)71 inline void check_and_update_common_device(std::optional<Device>& common_device, at::ITensorListRef tensors, at::CheckedFrom methodName, at::CheckedFrom argName) {
72 for (const auto& tensor : tensors) {
73 check_and_update_common_device(common_device, tensor, methodName, argName);
74 }
75 }
76
check_and_update_common_device(std::optional<Device> & common_device,const List<std::optional<at::Tensor>> & tensors,at::CheckedFrom methodName,at::CheckedFrom argName)77 inline void check_and_update_common_device(std::optional<Device>& common_device, const List<std::optional<at::Tensor>>& tensors, at::CheckedFrom methodName, at::CheckedFrom argName) {
78 for (const auto& tensor : tensors) {
79 check_and_update_common_device(common_device, tensor, methodName, argName);
80 }
81 }
82 } // namespace impl
83 } // namespace c10
84