xref: /aosp_15_r20/external/pytorch/aten/src/ATen/TensorSubclassLikeUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/List.h>
3 #include <ATen/core/Tensor.h>
4 #include <c10/core/impl/TorchDispatchModeTLS.h>
5 
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #else
9 #include <ATen/ops/equal.h>
10 #endif
11 
12 namespace at {
13 
14 // Note [Tensor-subclass-like Tensors]
15 // Tensor-subclass-like is defined as:
16 // - a Tensor subclass (via __torch_dispatch__ in Python or extending
17 //   TensorImpl in C++)
18 // - anything else that shares the same perils as Tensor subclasses.
19 //   For example, many Tensor subclasses do not have storage and meta Tensors
20 //   do not have storage either, so meta Tensors belong here.
21 //
// We should ensure that PyTorch internals support Tensor-subclass-like
23 // objects. In particular, Tensor-subclass-like objects struggle with two
24 // classes of operations that are problematic for Tensor subclasses:
25 // 1. Because some Tensor subclasses do not have storage, .item() or
26 //    .data_ptr() calls are not good.
27 // 2. Certain in-place operations can eliminate the typing of the Tensor
28 //    subclass. For example:
29 //    >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input)
30 //    If input is a Tensor subclass, then the above ends up either erroring out
31 //    or returning a regular non-Tensor-subclass Tensor!
32 
33 constexpr auto kFunctorchWrappedTensors = DispatchKeySet(
34     {DispatchKey::FuncTorchGradWrapper,
35      DispatchKey::FuncTorchBatched,
36      DispatchKey::Functionalize});
37 
38 constexpr auto kTensorSubclassLike =
39     kFunctorchWrappedTensors |
40     DispatchKeySet(
41         {// WARNING: DO NOT put combined backend component + functionality keys
42          // here, you will incorrectly always match on the functionality key
43          // no matter the backend component
44          DispatchKey::Batched,
45          DispatchKey::Sparse,
46          DispatchKey::SparseCsr,
47          DispatchKey::Python}) |
48     DispatchKeySet(BackendComponent::MetaBit);
49 
isTensorSubclassLike(const Tensor & tensor)50 inline bool isTensorSubclassLike(const Tensor& tensor) {
51   if (c10::impl::dispatch_mode_enabled())
52     return true;
53   auto key_set = tensor.unsafeGetTensorImpl()->key_set();
54   return !(key_set & kTensorSubclassLike).empty();
55 }
56 
areAnyTensorSubclassLike(TensorList tensors)57 inline bool areAnyTensorSubclassLike(TensorList tensors) {
58   if (c10::impl::dispatch_mode_enabled())
59     return true;
60   return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike);
61 }
62 
areAnyOptionalTensorSubclassLike(const c10::List<std::optional<Tensor>> & tensors)63 inline bool areAnyOptionalTensorSubclassLike(
64     const c10::List<std::optional<Tensor>>& tensors) {
65   if (c10::impl::dispatch_mode_enabled())
66     return true;
67   return std::any_of(
68       tensors.begin(),
69       tensors.end(),
70       [](const std::optional<Tensor>& opt_tensor) {
71         return (
72             opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value()));
73       });
74 }
75 
76 // Helper function to deal testing truthfulness of a scalar tensor
77 // in a Composite Compliant manner.
78 // NOTE: This function expects a scalar tensor of boolean dtype.
79 // Eg.
80 // Non-Composite Compliant Pattern : (t == 0).all().item<bool>()
81 // Composite Compliant Patter : is_salar_tensor_true((t == 0).all())
is_scalar_tensor_true(const Tensor & t)82 inline bool is_scalar_tensor_true(const Tensor& t) {
83   TORCH_INTERNAL_ASSERT(t.dim() == 0)
84   TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool)
85   return at::equal(t, t.new_ones({}, t.options()));
86 }
87 
88 } // namespace at
89