#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Distance.h>

#include <algorithm>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/cpu/vec/functional.h>
#include <c10/util/irange.h>

namespace at::native {
namespace {

template<typename scalar_t>
struct Dist {
  using Vec = vec::Vectorized<scalar_t>;

  // Depending on the value of the pnorm, there are specific implementations
  // that are much faster than std::pow(std::abs(a - b), p), but have the same
  // standard loop code for how to process the input vector. To reuse the main
  // outside loop while still guaranteeing that the compiler inlines every
  // different function on p, we break the inner norm logic into structs with
  // static functions that represent what's done differently, and template the
  // outer loop on those structs.
  //
  // The four functions are:
  //     map :      This tells how to modify (a - b) to form the component that
  //                gets summed.
  //     red :      This tells how to sum the result of map up. This is
  //                separate because the inf norm actually uses max instead of
  //                sum.
  //     finish :   This tells what to do with the aggregated value to compute
  //                the norm. Generally this is the result of val ^ (1 / p).
  //     backward : This is the gradient for that norm. Arguments are pretty
  //                self explanatory.
  //
  // There are a few cases where these aren't used. The 0 norm has no backward,
  // because it's always 0, so that's short-circuited earlier. There's a special
  // implementation of the general backward pass when p is less than two, so
  // there's a struct with only a backward pass for this case.
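  //
  // As a rough sketch (illustrative only; the actual loops below are
  // vectorized and parallelized), one distance between rows a and b composes
  // the pieces as:
  //     agg = 0;
  //     for each column c: agg = red(agg, map(|a[c] - b[c]|, p));
  //     *res = finish(agg, p);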

  // TODO This is an inefficient way to compute sign, and can be much faster
  // using native SSE instructions that should be added to Vectorized.
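  // The clamp trick below computes an elementwise sign without branching:
  // clamp(ceil(x), 0, 1) is 1 for x > 0 and 0 otherwise, and clamp(floor(x), -1, 0)
  // is -1 for x < 0 and 0 otherwise, so their sum is sign(x) with sign(0) == 0.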
  static inline Vec sign(Vec val) {
    return vec::minimum(vec::maximum(Vec(0), val.ceil()), Vec(1)) +
      vec::minimum(vec::maximum(Vec(-1), val.floor()), Vec(0));
  }

  static inline Vec abs(Vec val) {
    return val.abs();
  }

  static inline scalar_t abs(scalar_t val) {
    return std::abs(val);
  }

  static inline Vec ceil(Vec val) {
    return val.ceil();
  }

  static inline scalar_t ceil(scalar_t val) {
    return std::ceil(val);
  }

  static inline Vec min(Vec val, scalar_t other) {
    return vec::minimum(val, Vec(other));
  }

  static inline scalar_t min(scalar_t val, scalar_t other) {
    return std::min(val, other);
  }

  static inline Vec max(Vec val, Vec other) {
    return vec::maximum(val, other);
  }

  static inline scalar_t max(scalar_t val, scalar_t other) {
    return std::max(val, other);
  }

  static inline Vec pow(Vec val, Vec p) {
    return val.pow(p);
  }

  static inline scalar_t pow(scalar_t val, scalar_t p) {
    return std::pow(val, p);
  }

  // Zero norm
  template<typename data_t>
  struct zdist_calc {
    static inline data_t map(const data_t& diff, const data_t& p) { return min(ceil(abs(diff)), 1); }
    static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
    static inline scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return agg; }
  };

  // One norm
  template<typename data_t>
  struct odist_calc {
    static inline data_t map(const data_t& diff, const data_t& p) { return diff; }
    static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
    static inline scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return agg; }
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t /*dist*/, const Vec& /*p*/) { return Vec(grad) * sign(diff); }
  };

  // Special general pnorm derivative if p is less than two
  struct lttdist_calc {
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) {
      Vec result = (dist == 0.0) ? Vec(0) : (sign(diff) * diff.abs().pow(p - Vec(1)) * Vec(grad) / Vec(dist).pow(p - Vec(1)));
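      // With p < 1, the pow(p - 1) above evaluates 0 to a negative power when
      // diff == 0, which is not finite, so those lanes are masked back to 0.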
      result = Vec::blendv(result, Vec(0), (diff == Vec(0)) & (p < Vec(1)));
      return result;
    }
  };

  // Two norm
  template<typename data_t>
  struct tdist_calc {
    // TODO This can probably use fused add multiply to get better perf
    static inline data_t map(const data_t& diff, const data_t& p) { return diff * diff; }
    static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
    static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return std::sqrt(agg); }
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return dist == 0.0 ? Vec(0) : Vec(grad) * diff / Vec(dist); }
  };

  // General p norm
  template<typename data_t>
  struct pdist_calc {
    static inline data_t map(const data_t& diff, const data_t& p) { return pow(diff, p); }
    static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
    static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return std::pow(agg, 1.0 / p); }
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return dist == 0.0 ? Vec(0) : diff * diff.abs().pow(p - Vec(2)) * Vec(grad) / Vec(dist).pow(p - Vec(1)); }
  };

  // Inf norm
  template<typename data_t>
  struct idist_calc {
    static inline data_t map(const data_t& diff, const data_t& p) { return diff; }
    static inline data_t red(const data_t& agg, const data_t& up) { return max(agg, up); }
    static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return agg; }
    // TODO This backward pass uses a very complex expression to compute (diff
    // == dist) that could be much faster if using SSE instructions.
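    // The factor (1 - min(1, ceil(||diff| - dist|))) below is 1 exactly where
    // |diff| equals the max (dist) and 0 elsewhere, so only the maximal
    // element(s) receive a gradient.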
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return Vec(grad) * sign(diff) * (Vec(1) - vec::minimum(Vec(1), (diff.abs() - Vec(dist)).abs().ceil())); }
  };

  template <typename F>
  static void run_parallel_pdist(Tensor& result, const Tensor& self, const scalar_t p) {
    const scalar_t * const self_start = self.const_data_ptr<scalar_t>();
    const scalar_t * const self_end = self_start + self.numel();
    int64_t n = self.size(0);
    int64_t m = self.size(1);

    scalar_t * const res_start = result.data_ptr<scalar_t>();
    int64_t combs = result.numel(); // n * (n - 1) / 2

    // We conceptually iterate over tuples of (i, j, k) where i is the first
    // vector from the input, j is the second, and k is the result index. This
    // parallelizes over the range of k and infers what i and j are from the
    // value of k.
    parallel_for(0, combs, internal::GRAIN_SIZE / (16 * m), [p, self_start, self_end, n, m, res_start](int64_t k, int64_t end) {
      const Vec pvec(p);
      double n2 = n - .5;
      // The -1 accounts for floating point truncation issues
      // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
      int64_t i = static_cast<int64_t>((n2 - std::sqrt(n2 * n2 - 2 * k - 1)));
      int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;
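      // For example, with n = 4 the results are ordered (0,1), (0,2), (0,3),
      // (1,2), (1,3), (2,3), and k = 4 recovers i = 1, j = 3.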

      const scalar_t * self_i = self_start + i * m;
      const scalar_t * self_j = self_start + j * m;
      scalar_t * res = res_start + k;
      const scalar_t * const res_end = res_start + end;

      while (res != res_end) {
        *res = F::finish(vec::map2_reduce_all<scalar_t>(
          [&pvec](Vec a, Vec b) { return F::map((a - b).abs(), pvec); },
          F::red, self_i, self_j, m), p);

        res += 1;
        self_j += m;
        if (self_j == self_end) {
          self_i += m;
          self_j = self_i + m;
        }
      }
    });
  }

  // Assumes self is nonempty, contiguous, and 2D
  static void apply_pdist(Tensor& result, const Tensor& self, const scalar_t p) {
    if (p == 0.0) {
      run_parallel_pdist<zdist_calc<Vec>>(result, self, p);
    } else if (p == 1.0) {
      run_parallel_pdist<odist_calc<Vec>>(result, self, p);
    } else if (p == 2.0) {
      run_parallel_pdist<tdist_calc<Vec>>(result, self, p);
    } else if (std::isinf(p)) {
      run_parallel_pdist<idist_calc<Vec>>(result, self, p);
    } else {
      run_parallel_pdist<pdist_calc<Vec>>(result, self, p);
    }
  }

  template <typename F>
  static void run_parallel_cdist(Tensor& result, const Tensor& t1, const Tensor& t2, const scalar_t p) {
    const scalar_t * const t1_start = t1.const_data_ptr<scalar_t>();
    const scalar_t * const t2_start = t2.const_data_ptr<scalar_t>();
    int64_t d = t1.size(0);
    int64_t r1 = t1.size(-2);
    int64_t r2 = t2.size(-2);
    int64_t m = t1.size(-1);

    scalar_t * const res_start = result.data_ptr<scalar_t>();
    int64_t combs = r1 * r2;
    int64_t size1 = r1 * m;
    int64_t size2 = r2 * m;

    parallel_for(0, combs * d, internal::GRAIN_SIZE / (16 * m), [=](int64_t start, int64_t end) {
      scalar_t * res = res_start + start;
      const scalar_t * const res_end = res_start + end;
      int64_t l = start / combs;
      int64_t k = start % combs;
      int64_t i = k / r2;
      int64_t j = k % r2;
      i = i * m;
      j = j * m;
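      // start indexes the flattened (batch, row1, row2) space: l is the batch,
      // and i / j are converted from row indices into element offsets into
      // t1 / t2.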

      while (res != res_end) {
        const scalar_t * self_i = t1_start + size1 * l + i;
        const scalar_t * self_j = t2_start + size2 * l + j;

        scalar_t agg = 0;
        for (const auto x : c10::irange(m)) {
          scalar_t a = *(self_i + x);
          scalar_t b = *(self_j + x);
          agg = F::red(agg, F::map(std::abs(a-b), p));
        }
        *res = F::finish(agg, p);

        res += 1;
        j += m;
        if (j == size2) {
          j = 0;
          i += m;
          if (i == size1) {
            i = 0;
            l += 1;
          }
        }
      }
    });
  }

  static void apply_cdist(Tensor& result, const Tensor& x1, const Tensor& x2, const scalar_t p) {
    if (p == 0.0) {
      run_parallel_cdist<zdist_calc<scalar_t>>(result, x1, x2, p);
    } else if (p == 1.0) {
      run_parallel_cdist<odist_calc<scalar_t>>(result, x1, x2, p);
    } else if (p == 2.0) {
      run_parallel_cdist<tdist_calc<scalar_t>>(result, x1, x2, p);
    } else if (std::isinf(p)) {
      run_parallel_cdist<idist_calc<scalar_t>>(result, x1, x2, p);
    } else {
      run_parallel_cdist<pdist_calc<scalar_t>>(result, x1, x2, p);
    }
  }

  // This does a backward pass down a Vec column of the input
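  // Each pairwise distance depends on rows i and j symmetrically, so every
  // backward contribution is added to row i and subtracted from row j, and
  // grad_k / dist_k are advanced in the same (i, j) order as the forward pass.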
  template <typename F>
  inline static void backward_down_column_pdist(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size()) {
    for (const scalar_t * const self_end = self_i + m * n; self_i != self_end - m; self_i += m, res_i += m) {

      const Vec self_vec_i = Vec::loadu(self_i, count);
      Vec res_vec_i = Vec::loadu(res_i, count);

      const scalar_t * self_j = self_i + m;
      scalar_t * res_j = res_i + m;
      for (; self_j != self_end; self_j += m, res_j += m, grad_k += gs, dist_k += 1) {
        const Vec self_vec_j = Vec::loadu(self_j, count);
        Vec res_vec_j = Vec::loadu(res_j, count);

        Vec res = F::backward(self_vec_i - self_vec_j, *grad_k, *dist_k, pvec);
        res_vec_i = res_vec_i + res;
        res_vec_j = res_vec_j - res;

        res_vec_j.store(res_j, count);
      }

      res_vec_i.store(res_i, count);
    }
  }

  template <typename F>
  static void run_backward_parallel_pdist(Tensor& result, const Tensor & grad, const Tensor & self, const scalar_t p, const Tensor& dist) {
    const int64_t n = self.size(0);
    const int64_t m = self.size(1);
    const int64_t gs = grad.stride(0);

    const scalar_t * const grad_start = grad.const_data_ptr<scalar_t>();
    const scalar_t * const dist_start = dist.const_data_ptr<scalar_t>();
    const scalar_t * const self_start = self.const_data_ptr<scalar_t>();
    scalar_t * const res_start = result.data_ptr<scalar_t>();

    // The only way to parallelize and avoid locking requires parallelizing
    // over the columns of the input, i.e. we compute the gradient for the
    // first section of each vector independently of the second section, etc.
    at::parallel_for(0, m / Vec::size(), internal::GRAIN_SIZE / (8 * n * n), [p, n, m, gs, grad_start, dist_start, self_start, res_start](int64_t l, int64_t end) {
      const Vec pvec(p);

      const scalar_t * self_l = self_start + l * Vec::size();
      scalar_t * res_l = res_start + l * Vec::size();

      for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; self_l += Vec::size(), res_l += Vec::size()) {
        backward_down_column_pdist<F>(self_l, res_l, grad_start, dist_start, pvec, n, m, gs);
      }
    });
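    // Columns left over when m is not a multiple of Vec::size() are handled
    // below with partial-width loads and stores via the explicit count argument.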
    const int64_t remainder = m % Vec::size();
    if (remainder) {
      backward_down_column_pdist<F>(self_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, Vec(p), n, m, gs, remainder);
    }
  }

  // Assumes self is nonempty, contiguous, and 2D and dist is also contiguous
  static void apply_backward_pdist(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) {
    result.fill_(0);
    if (p == 0.0) {
    } else if (p == 1.0) {
      run_backward_parallel_pdist<odist_calc<Vec>>(result, grad, self, p, dist);
    } else if (p < 2.0) {
      run_backward_parallel_pdist<lttdist_calc>(result, grad, self, p, dist);
    } else if (p == 2.0) {
      run_backward_parallel_pdist<tdist_calc<Vec>>(result, grad, self, p, dist);
    } else if (std::isinf(p)) {
      run_backward_parallel_pdist<idist_calc<Vec>>(result, grad, self, p, dist);
    } else {
      run_backward_parallel_pdist<pdist_calc<Vec>>(result, grad, self, p, dist);
    }
  }

  static void apply_backward_cdist(Tensor& result, const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& dist) {
    result.fill_(0);
    if (p == 0.0) {
    } else if (p == 1.0) {
      run_backward_parallel_cdist<odist_calc<Vec>>(result, grad, x1, x2, p, dist);
    } else if (p < 2.0) {
      run_backward_parallel_cdist<lttdist_calc>(result, grad, x1, x2, p, dist);
    } else if (p == 2.0) {
      run_backward_parallel_cdist<tdist_calc<Vec>>(result, grad, x1, x2, p, dist);
    } else if (std::isinf(p)) {
      run_backward_parallel_cdist<idist_calc<Vec>>(result, grad, x1, x2, p, dist);
    } else {
      run_backward_parallel_cdist<pdist_calc<Vec>>(result, grad, x1, x2, p, dist);
    }
  }


  template <typename F>
  static void run_backward_parallel_cdist(Tensor& result, const Tensor & grad, const Tensor & t1, const Tensor & t2, const scalar_t p, const Tensor& dist) {
    const int64_t r1 = t1.size(-2);
    const int64_t r2 = t2.size(-2);
    const int64_t m = t1.size(-1);
    const int64_t d = result.size(0);
    const int64_t l1_size = r1 * m;
    const int64_t l2_size = r2 * m;
    // The current implementation supports only tensors that can be collapsed to 1D. However, to avoid
    // checking whether grad satisfies this assumption, we call .contiguous() on grad before the backward
    // pass, so its stride is guaranteed to be 1.
    // Don't use grad.stride(-1), because if the last dimension is 1, the stride can be bogus.
    const int64_t gs = 1;

    const scalar_t * const grad_start = grad.const_data_ptr<scalar_t>();
    const scalar_t * const dist_start = dist.const_data_ptr<scalar_t>();
    const scalar_t * const t1_start = t1.const_data_ptr<scalar_t>();
    const scalar_t * const t2_start = t2.const_data_ptr<scalar_t>();
    scalar_t * const res_start = result.data_ptr<scalar_t>();

    at::parallel_for(0, m / Vec::size(), internal::GRAIN_SIZE / (16 * r1), [=](int64_t l, int64_t end) {
      const Vec pvec(p);

      const scalar_t * i = t1_start + l * Vec::size();
      const scalar_t * j = t2_start + l * Vec::size();
      scalar_t * res_l = res_start + l * Vec::size();

      for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; i += Vec::size(), j += Vec::size(), res_l += Vec::size()) {
        backward_down_column_cdist<F>(i, j, res_l, grad_start, dist_start, pvec, r1, r2, m, d, gs, l1_size, l2_size);
      }
    });
    const int64_t remainder = m % Vec::size();
    if (remainder) {
      backward_down_column_cdist<F>(t1_start + (m - remainder), t2_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, Vec(p), r1, r2, m, d, gs, l1_size, l2_size, remainder);
    }
  }

  template <typename F>
  inline static void backward_down_column_cdist(const scalar_t * t1, const scalar_t * t2, scalar_t * res, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t r1, int64_t r2, int64_t m, int64_t d, int64_t gs, int64_t l1_size, int64_t l2_size, int64_t count = Vec::size()) {
    const scalar_t * t1_end = t1 + l1_size;
    const scalar_t * t2_end = t2 + l2_size;

    for (const auto l C10_UNUSED : c10::irange(d)) {
      for (; t1 != t1_end; t1 += m, res += m) {
        const Vec vec_t1 = Vec::loadu(t1, count);
        Vec res_vec = Vec::loadu(res, count);

        for (const scalar_t * t2_curr = t2; t2_curr != t2_end; t2_curr += m, grad_k += gs, dist_k += 1) {
          const Vec vec_t2 = Vec::loadu(t2_curr, count);
          Vec res = F::backward(vec_t1 - vec_t2, *grad_k, *dist_k, pvec);
          res_vec = res_vec + res;
        }

        res_vec.store(res, count);
      }
      t1_end += l1_size;
      t2_end += l2_size;
      t2 += l2_size;
    }
  }

};

void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) {
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist", [&] {
    Dist<scalar_t>::apply_pdist(result, self, p);
  });
}

static void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) {
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_backward", [&] {
    Dist<scalar_t>::apply_backward_pdist(result, grad, self, p, dist);
  });
}

static void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, const double p) {
  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist", [&] {
    Dist<scalar_t>::apply_cdist(result, x1, x2, p);
  });
}

static void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& dist) {
  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist_backward", [&] {
    Dist<scalar_t>::apply_backward_cdist(result, grad, x1, x2, p, dist);
  });
}


} // anonymous namespace

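// The stubs below come from ATen/native/Distance.h (included above);
// registering these kernels is what routes the device-generic pdist/cdist
// entry points to the CPU implementations in this file.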
REGISTER_DISPATCH(pdist_forward_stub, &pdist_forward_kernel_impl);
REGISTER_DISPATCH(pdist_backward_stub, &pdist_backward_kernel_impl);
REGISTER_DISPATCH(cdist_stub, &cdist_kernel_impl);
REGISTER_DISPATCH(cdist_backward_stub, &cdist_backward_kernel_impl);

} // namespace at::native