// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Correlation.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/TensorOperators.h>
4 #include <ATen/TensorSubclassLikeUtils.h>
5 
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #include <ATen/NativeFunctions.h>
9 #else
10 #include <ATen/ops/complex.h>
11 #include <ATen/ops/corrcoef_native.h>
12 #include <ATen/ops/cov.h>
13 #include <ATen/ops/cov_native.h>
14 #include <ATen/ops/imag.h>
15 #include <ATen/ops/mm.h>
16 #include <ATen/ops/real.h>
17 #include <ATen/ops/scalar_tensor.h>
18 #include <ATen/ops/sqrt.h>
19 #include <ATen/ops/true_divide.h>
20 #endif
21 
22 namespace at::native {
23 
cov(const Tensor & self,int64_t correction,const std::optional<Tensor> & fweights,const std::optional<Tensor> & aweights)24 Tensor cov(
25     const Tensor& self,
26     int64_t correction,
27     const std::optional<Tensor>& fweights,
28     const std::optional<Tensor>& aweights) {
29   constexpr int64_t OBSERVATIONS_DIM = 1;
30 
31   TORCH_CHECK(
32       self.ndimension() <= 2,
33       "cov(): expected input to have two or fewer dimensions but got an input with ",
34       self.ndimension(),
35       " dimensions");
36 
37   TORCH_CHECK(
38       self.scalar_type() != kBool,
39       "cov(): bool dtype is not supported for input");
40 
41   // View input tensor as 2D (variables, observations)
42   auto in = self.ndimension() < 2 ? self.view({1, -1}) : self;
43   const auto num_observations = in.size(OBSERVATIONS_DIM);
44 
45   // The product of frequencies (fweights) and weights (aweights).
46   Tensor w;
47 
48   if (fweights.has_value()) {
49     w = fweights.value();
50     TORCH_CHECK(
51         w.ndimension() <= 1,
52         "cov(): expected fweights to have one or fewer dimensions but got fweights with ",
53         w.ndimension(),
54         " dimensions");
55     TORCH_CHECK(
56         at::isIntegralType(w.scalar_type(), false),
57         "cov(): expected fweights to have integral dtype but got fweights with ",
58         w.scalar_type(),
59         " dtype");
60     TORCH_CHECK(
61         w.numel() == num_observations,
62         "cov(): expected fweights to have the same numel as there are observations in the input but got ",
63         w.numel(),
64         " != ",
65         num_observations);
66     TORCH_CHECK(
67         num_observations == 0 || at::is_scalar_tensor_true(w.min().ge(0)),
68         "cov(): fweights cannot be negative");
69   }
70 
71   if (aweights.has_value()) {
72     const auto& aw = aweights.value();
73     TORCH_CHECK(
74         aw.ndimension() <= 1,
75         "cov(): expected aweights to have one or fewer dimensions but got aweights with ",
76         aw.ndimension(),
77         " dimensions");
78     TORCH_CHECK(
79         at::isFloatingType(aw.scalar_type()),
80         "cov(): expected aweights to have floating point dtype but got aweights with ",
81         aw.scalar_type(),
82         " dtype");
83     TORCH_CHECK(
84         aw.numel() == num_observations,
85         "cov(): expected aweights to have the same numel as there are observations in the input but got ",
86         aw.numel(),
87         " != ",
88         num_observations);
89     TORCH_CHECK(
90         num_observations == 0 || at::is_scalar_tensor_true(aw.min().ge(0)),
91         "cov(): aweights cannot be negative");
92     w = w.defined() ? w * aw : aw;
93   }
94 
95   // Compute a weighted average of the observations
96   const auto w_sum = w.defined()
97       ? w.sum()
98       : at::scalar_tensor(num_observations, in.options().dtype(kLong));
99 
100   TORCH_CHECK(
101       !w.defined() || at::is_scalar_tensor_true(w_sum.ne(0)),
102       "cov(): weights sum to zero, can't be normalized");
103 
104   const auto avg = (w.defined() ? in * w : in).sum(OBSERVATIONS_DIM) / w_sum;
105 
106   // Compute the normalization factor
107   Tensor norm_factor;
108 
109   if (w.defined() && aweights.has_value() && correction != 0) {
110     norm_factor = w_sum - correction * (w * aweights.value()).sum() / w_sum;
111   } else {
112     norm_factor = w_sum - correction;
113   }
114 
115   if (at::is_scalar_tensor_true(norm_factor.le(0))) {
116     TORCH_WARN("cov(): degrees of freedom is <= 0. Correction should be strictly less than the number of observations.");
117     norm_factor.zero_();
118   }
119 
120   // Compute covariance matrix
121   in = in - avg.unsqueeze(1);
122   const auto c = at::mm(in, (w.defined() ? in * w : in).t().conj());
123   return at::true_divide(c, norm_factor).squeeze();
124 }
125 
// Computes the Pearson correlation coefficient matrix of the variables in
// `self` by normalizing the covariance matrix with the per-variable standard
// deviations, clipping the result to [-1, 1] as NumPy does.
Tensor corrcoef(const Tensor& self) {
  TORCH_CHECK(
      self.ndimension() <= 2,
      "corrcoef(): expected input to have two or fewer dimensions but got an input with ",
      self.ndimension(),
      " dimensions");

  auto cov_mat = at::cov(self);

  if (cov_mat.ndimension() == 0) {
    // Single-variable case: self-correlation is 1, and a covariance of
    // nan/inf/0 propagates to nan through the division.
    return cov_mat / cov_mat;
  }

  // Divide rows then columns by the standard deviations (two sequential
  // divisions, matching the original rounding behavior). For complex input
  // the diagonal variances are real-valued.
  const auto variances = cov_mat.diagonal();
  const auto stddev =
      at::sqrt(variances.is_complex() ? at::real(variances) : variances);
  cov_mat = cov_mat / stddev.view({-1, 1});
  cov_mat = cov_mat / stddev.view({1, -1});

  // Floating point rounding may push entries slightly outside [-1, 1]; clip
  // them back (component-wise for complex values), as NumPy does.
  if (cov_mat.is_complex()) {
    return at::complex(
        at::real(cov_mat).clip(-1, 1), at::imag(cov_mat).clip(-1, 1));
  }
  return cov_mat.clip(-1, 1);
}
152 
153 } // namespace at::native
154