xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Onehot.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 
4 #ifndef AT_PER_OPERATOR_HEADERS
5 #include <ATen/Functions.h>
6 #include <ATen/NativeFunctions.h>
7 #else
8 #include <ATen/ops/arange.h>
9 #include <ATen/ops/empty.h>
10 #include <ATen/ops/eq.h>
11 #include <ATen/ops/one_hot_native.h>
12 #include <ATen/ops/zeros.h>
13 #endif
14 
15 namespace at::native {
16 
/// Converts a tensor of class indices into its one-hot encoding.
///
/// @param self         Index tensor; must have dtype kLong.
/// @param num_classes  Size of the trailing dimension appended to the result,
///                     or -1 to infer it as `self.max() + 1`.
/// @return A kLong tensor whose shape is `self.sizes()` with `num_classes`
///         appended. For non-empty inputs the position selected by each index
///         holds 1 and every other position holds 0.
/// @throws c10::Error (via TORCH_CHECK) when `self` is not kLong, when `self`
///         is empty and `num_classes` is not positive, or when one of the
///         host-side range checks below fails.
Tensor one_hot(const Tensor &self, int64_t num_classes) {
    TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor of type LongTensor.");

    // using meta bit test to catch Fake Tensor as well until __torch_function__
    if (self.key_set().has_all(DispatchKeySet(BackendComponent::MetaBit)) ||
            self.key_set().has_all(DispatchKeySet(DispatchKey::Python))) {
        // functional version that torch.compiles better and works with dynamic shapes
        if (num_classes == -1) {
          num_classes = self.max().item().toLong() + 1;
        }
        at::Tensor index = at::arange(num_classes, self.options());
        return at::eq(self.unsqueeze(-1), index).to(kLong);
    }

    auto shape = self.sizes().vec();

    // empty tensor could be converted to one hot representation,
    // but shape inference is not possible.
    if (self.numel() == 0) {
        TORCH_CHECK(num_classes > 0, "Can not infer total number of classes from empty tensor.");
        shape.push_back(num_classes);
        return at::empty(shape, self.options());
    }

    // non-empty tensor
    //
    // On CUDA, MPS, PrivateUse1 and XLA the index range is not validated here:
    // reading min()/max() through item() would force a sync, so those devices
    // rely on the device-side assert thrown by scatter instead.
    const auto device_type = self.device().type();
    const bool validate_on_host =
        device_type != at::kCUDA && device_type != at::kMPS &&
        device_type != at::kPrivateUse1 && device_type != at::kXLA;

    if (validate_on_host) {
      TORCH_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
    }
    if (num_classes == -1) {
        // infer the class count from the largest index present
        num_classes = self.max().item().toLong() + 1;
    } else if (validate_on_host) {
        TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
    } else {
        // no host-side range check on these devices; only assert that
        // num_classes is at least 1
        TORCH_CHECK(num_classes >= 1, "num_classes should be positive");
    }

    shape.push_back(num_classes);
    Tensor ret = at::zeros(shape, self.options());
    // write a 1 at the class index of every element along the new last dim
    ret.scatter_(-1, self.unsqueeze(-1), 1);
    return ret;
}
68 
69 } // namespace at::native
70