#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorSubclassLikeUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/complex.h>
#include <ATen/ops/corrcoef_native.h>
#include <ATen/ops/cov.h>
#include <ATen/ops/cov_native.h>
#include <ATen/ops/imag.h>
#include <ATen/ops/mm.h>
#include <ATen/ops/real.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/sqrt.h>
#include <ATen/ops/true_divide.h>
#endif

namespace at::native {
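// Weighted covariance of the variables in `self` (rows are variables,
// columns are observations), following torch.cov semantics: observations
// may be weighted by integer frequencies (fweights) and/or non-negative
// floating point weights (aweights).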
Tensor cov(
    const Tensor& self,
    int64_t correction,
    const std::optional<Tensor>& fweights,
    const std::optional<Tensor>& aweights) {
  constexpr int64_t OBSERVATIONS_DIM = 1;

  TORCH_CHECK(
      self.ndimension() <= 2,
      "cov(): expected input to have two or fewer dimensions but got an input with ",
      self.ndimension(),
      " dimensions");

  TORCH_CHECK(
      self.scalar_type() != kBool,
      "cov(): bool dtype is not supported for input");

  // View input tensor as 2D (variables, observations)
  auto in = self.ndimension() < 2 ? self.view({1, -1}) : self;
  const auto num_observations = in.size(OBSERVATIONS_DIM);

  // The product of frequencies (fweights) and weights (aweights).
  Tensor w;

  if (fweights.has_value()) {
    w = fweights.value();
    TORCH_CHECK(
        w.ndimension() <= 1,
        "cov(): expected fweights to have one or fewer dimensions but got fweights with ",
        w.ndimension(),
        " dimensions");
    TORCH_CHECK(
        at::isIntegralType(w.scalar_type(), false),
        "cov(): expected fweights to have integral dtype but got fweights with ",
        w.scalar_type(),
        " dtype");
    TORCH_CHECK(
        w.numel() == num_observations,
        "cov(): expected fweights to have the same numel as there are observations in the input but got ",
        w.numel(),
        " != ",
        num_observations);
    TORCH_CHECK(
        num_observations == 0 || at::is_scalar_tensor_true(w.min().ge(0)),
        "cov(): fweights cannot be negative");
  }

  if (aweights.has_value()) {
    const auto& aw = aweights.value();
    TORCH_CHECK(
        aw.ndimension() <= 1,
        "cov(): expected aweights to have one or fewer dimensions but got aweights with ",
        aw.ndimension(),
        " dimensions");
    TORCH_CHECK(
        at::isFloatingType(aw.scalar_type()),
        "cov(): expected aweights to have floating point dtype but got aweights with ",
        aw.scalar_type(),
        " dtype");
    TORCH_CHECK(
        aw.numel() == num_observations,
        "cov(): expected aweights to have the same numel as there are observations in the input but got ",
        aw.numel(),
        " != ",
        num_observations);
    TORCH_CHECK(
        num_observations == 0 || at::is_scalar_tensor_true(aw.min().ge(0)),
        "cov(): aweights cannot be negative");
    w = w.defined() ? w * aw : aw;
  }

  // Compute a weighted average of the observations
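  // avg[i] = sum_j w[j] * in[i, j] / sum(w); when no weights are given,
  // w_sum is simply the number of observations.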
  const auto w_sum = w.defined()
      ? w.sum()
      : at::scalar_tensor(num_observations, in.options().dtype(kLong));

  TORCH_CHECK(
      !w.defined() || at::is_scalar_tensor_true(w_sum.ne(0)),
      "cov(): weights sum to zero, can't be normalized");

  const auto avg = (w.defined() ? in * w : in).sum(OBSERVATIONS_DIM) / w_sum;

  // Compute the normalization factor
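  // norm_factor = sum(w) - correction * sum(w * aweights) / sum(w), which
  // reduces to num_observations - correction when no weights are given
  // (this mirrors numpy.cov's normalization).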
  Tensor norm_factor;

  if (w.defined() && aweights.has_value() && correction != 0) {
    norm_factor = w_sum - correction * (w * aweights.value()).sum() / w_sum;
  } else {
    norm_factor = w_sum - correction;
  }

  if (at::is_scalar_tensor_true(norm_factor.le(0))) {
    TORCH_WARN("cov(): degrees of freedom is <= 0. Correction should be strictly less than the number of observations.");
    norm_factor.zero_();
  }

  // Compute covariance matrix
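  // cov = (X - avg) * diag(w) * (X - avg)^H / norm_factor, computed here as
  // mm(in, (in * w).t().conj()) on the centered input.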
  in = in - avg.unsqueeze(1);
  const auto c = at::mm(in, (w.defined() ? in * w : in).t().conj());
  return at::true_divide(c, norm_factor).squeeze();
}

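// Pearson product-moment correlation coefficients: the covariance matrix of
// `self`, normalized by the per-variable standard deviations and clipped to
// [-1, 1] as NumPy does.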
Tensor corrcoef(const Tensor& self) {
  TORCH_CHECK(
      self.ndimension() <= 2,
      "corrcoef(): expected input to have two or fewer dimensions but got an input with ",
      self.ndimension(),
      " dimensions");

  auto c = at::cov(self);

  if (c.ndimension() == 0) {
    // scalar covariance, return nan if c in {nan, inf, 0}, 1 otherwise
    return c / c;
  }

  // normalize covariance
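  // corr[i, j] = cov[i, j] / (stddev[i] * stddev[j]); for complex inputs the
  // diagonal is real, so the standard deviations are taken from its real part.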
  const auto d = c.diagonal();
  const auto stddev = at::sqrt(d.is_complex() ? at::real(d) : d);
  c = c / stddev.view({-1, 1});
  c = c / stddev.view({1, -1});

  // Due to floating point rounding, the values may not be within [-1, 1], so
  // we clip them just as NumPy does to improve the result.
  return c.is_complex()
      ? at::complex(at::real(c).clip(-1, 1), at::imag(c).clip(-1, 1))
      : c.clip(-1, 1);
}

} // namespace at::native