// aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Distance.h>

#include <algorithm>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/cpu/vec/functional.h>
#include <c10/util/irange.h>

namespace at::native {
namespace {

template<typename scalar_t>
struct Dist {
  using Vec = vec::Vectorized<scalar_t>;

  // Depending on the value of the pnorm, there are specific implementations
  // that are much faster than std::pow(std::abs(a - b), p), but share the same
  // standard loop code for how to process the input vector. To reuse the main
  // outside loop while still guaranteeing that the compiler inlines every
  // different function on p, we break the inner norm logic into structs with
  // static functions that represent what's done differently, and template the
  // outer loop on those structs.
  //
  // The four functions are:
  //     map :      This tells how to modify (a - b) to form the component that
  //                gets summed.
  //     red :      This tells how to sum the results of map up. This is
  //                separate because the inf norm actually uses max instead of
  //                sum.
  //     finish :   This tells what to do with the aggregated value to compute
  //                the norm. Generally this is the result of val ^ (1 / p).
  //     backward : This is the gradient for that norm. Arguments are pretty
  //                self-explanatory.
  //
  // There are a few cases where these aren't used. The 0 norm has no backward,
  // because it's always 0, so that's short-circuited earlier. There's a special
  // implementation of the general backward pass when p is less than two, so
  // there's a struct with only a backward pass for this case.
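  //
  // Illustrative sketch of how the pieces compose for one pair of rows a, b
  // (the scalar loop in run_parallel_cdist below spells this out):
  //   agg = 0;
  //   for each component x: agg = red(agg, map(|a[x] - b[x]|, p));
  //   dist = finish(agg, p);
  // e.g. for the 2-norm, map squares the difference, red adds, and finish
  // takes the square root.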

  // TODO This is an inefficient way to compute sign, and could be much faster
  // using native SSE instructions that should be added to Vectorized.
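  // min(max(0, ceil(v)), 1) is 1 for v > 0 and 0 otherwise, while
  // min(max(-1, floor(v)), 0) is -1 for v < 0 and 0 otherwise; their sum is
  // the elementwise sign of v (-1, 0, or +1).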
  static inline Vec sign(Vec val) {
    return vec::minimum(vec::maximum(Vec(0), val.ceil()), Vec(1)) +
      vec::minimum(vec::maximum(Vec(-1), val.floor()), Vec(0));
  }

  static inline Vec abs(Vec val) {
    return val.abs();
  }

  static inline scalar_t abs(scalar_t val) {
    return std::abs(val);
  }

  static inline Vec ceil(Vec val) {
    return val.ceil();
  }

  static inline scalar_t ceil(scalar_t val) {
    return std::ceil(val);
  }

  static inline Vec min(Vec val, scalar_t other) {
    return vec::minimum(val, Vec(other));
  }

  static inline scalar_t min(scalar_t val, scalar_t other) {
    return std::min(val, other);
  }

  static inline Vec max(Vec val, Vec other) {
    return vec::maximum(val, other);
  }

  static inline scalar_t max(scalar_t val, scalar_t other) {
    return std::max(val, other);
  }

  static inline Vec pow(Vec val, Vec p) {
    return val.pow(p);
  }

  static inline scalar_t pow(scalar_t val, scalar_t p) {
    return std::pow(val, p);
  }

  // Zero norm
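  // map turns each |diff| into a 0/1 indicator of being nonzero, so the sum
  // counts the number of differing components (a Hamming-style distance).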
  template<typename data_t>
  struct zdist_calc {
    static inline data_t map(const data_t& diff, const data_t& p) { return min(ceil(abs(diff)), 1); }
    static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
    static inline scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return agg; }
  };

  // One norm
  template<typename data_t>
  struct odist_calc {
    static inline data_t map(const data_t& diff, const data_t& p) { return diff; }
    static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
    static inline scalar_t finish(const scalar_t agg, const scalar_t /*p*/) { return agg; }
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t /*dist*/, const Vec& /*p*/) { return Vec(grad) * sign(diff); }
  };

  // Special general pnorm derivative if p is less than two
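  // For p < 1 the exponent p - 1 is negative, so lanes where diff == 0 would
  // produce 0^(negative) = inf; the blendv below zeroes those lanes. A zero
  // distance likewise short-circuits to a zero gradient.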
  struct lttdist_calc {
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) {
      Vec result = (dist == 0.0) ? Vec(0) : (sign(diff) * diff.abs().pow(p - Vec(1)) * Vec(grad) / Vec(dist).pow(p - Vec(1)));
      result = Vec::blendv(result, Vec(0), (diff == Vec(0)) & (p < Vec(1)));
      return result;
    }
  };

  // Two norm
  template<typename data_t>
  struct tdist_calc {
    // TODO This can probably use fused add multiply to get better perf
    static inline data_t map(const data_t& diff, const data_t& p) { return diff * diff; }
    static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
    static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return std::sqrt(agg); }
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return dist == 0.0 ? Vec(0) : Vec(grad) * diff / Vec(dist); }
  };

  // General p norm
  template<typename data_t>
  struct pdist_calc {
    static inline data_t map(const data_t& diff, const data_t& p) { return pow(diff, p); }
    static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
    static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return std::pow(agg, 1.0 / p); }
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return dist == 0.0 ? Vec(0) : diff * diff.abs().pow(p - Vec(2)) * Vec(grad) / Vec(dist).pow(p - Vec(1)); }
  };

  // Inf norm
  template<typename data_t>
  struct idist_calc {
    static inline data_t map(const data_t& diff, const data_t& p) { return diff; }
    static inline data_t red(const data_t& agg, const data_t& up) { return max(agg, up); }
    static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return agg; }
    // TODO This backward pass uses a very complex expression to compute (diff
    // == dist) that could be much faster if using SSE instructions.
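    // 1 - min(1, ceil(||diff| - dist|)) is 1 exactly where |diff| equals the
    // max element (i.e. the distance) and 0 elsewhere, so only the maximizing
    // component receives gradient.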
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return Vec(grad) * sign(diff) * (Vec(1) - vec::minimum(Vec(1), (diff.abs() - Vec(dist)).abs().ceil())); }
  };

  template <typename F>
  static void run_parallel_pdist(Tensor& result, const Tensor& self, const scalar_t p) {
    const scalar_t * const self_start = self.const_data_ptr<scalar_t>();
    const scalar_t * const self_end = self_start + self.numel();
    int64_t n = self.size(0);
    int64_t m = self.size(1);

    scalar_t * const res_start = result.data_ptr<scalar_t>();
    int64_t combs = result.numel(); // n * (n - 1) / 2

    // We conceptually iterate over tuples of (i, j, k) where i is the first
    // vector from the input, j is the second, and k is the result index. This
    // parallelizes over the range of k and infers what i and j are from the
    // value of k.
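    // The closed form below inverts the triangular layout of the upper
    // triangle: for n = 4 the results are ordered (0,1), (0,2), (0,3), (1,2),
    // (1,3), (2,3), so e.g. k = 4 recovers i = 1, j = 3.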
    parallel_for(0, combs, internal::GRAIN_SIZE / (16 * m), [p, self_start, self_end, n, m, res_start](int64_t k, int64_t end) {
      const Vec pvec(p);
      double n2 = n - .5;
      // The -1 accounts for floating point truncation issues
      // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
      int64_t i = static_cast<int64_t>((n2 - std::sqrt(n2 * n2 - 2 * k - 1)));
      int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;

      const scalar_t * self_i = self_start + i * m;
      const scalar_t * self_j = self_start + j * m;
      scalar_t * res = res_start + k;
      const scalar_t * const res_end = res_start + end;

      while (res != res_end) {
        *res = F::finish(vec::map2_reduce_all<scalar_t>(
          [&pvec](Vec a, Vec b) { return F::map((a - b).abs(), pvec); },
          F::red, self_i, self_j, m), p);

        res += 1;
        self_j += m;
        if (self_j == self_end) {
          self_i += m;
          self_j = self_i + m;
        }
      }
    });
  }

  // Assumes self is nonempty, contiguous, and 2D
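  // Dispatches to the specialization for p = 0, 1, 2, or inf; every other p
  // goes through the general pdist_calc path.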
  static void apply_pdist(Tensor& result, const Tensor& self, const scalar_t p) {
    if (p == 0.0) {
      run_parallel_pdist<zdist_calc<Vec>>(result, self, p);
    } else if (p == 1.0) {
      run_parallel_pdist<odist_calc<Vec>>(result, self, p);
    } else if (p == 2.0) {
      run_parallel_pdist<tdist_calc<Vec>>(result, self, p);
    } else if (std::isinf(p)) {
      run_parallel_pdist<idist_calc<Vec>>(result, self, p);
    } else {
      run_parallel_pdist<pdist_calc<Vec>>(result, self, p);
    }
  }

  template <typename F>
  static void run_parallel_cdist(Tensor& result, const Tensor& t1, const Tensor& t2, const scalar_t p) {
    const scalar_t * const t1_start = t1.const_data_ptr<scalar_t>();
    const scalar_t * const t2_start = t2.const_data_ptr<scalar_t>();
    int64_t d = t1.size(0);
    int64_t r1 = t1.size(-2);
    int64_t r2 = t2.size(-2);
    int64_t m = t1.size(-1);

    scalar_t * const res_start = result.data_ptr<scalar_t>();
    int64_t combs = r1 * r2;
    int64_t size1 = r1 * m;
    int64_t size2 = r2 * m;

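    // The flat index ranges over d * r1 * r2 results; it decomposes into the
    // batch l, the row i of t1, and the row j of t2 (i and j are then scaled
    // by m to become element offsets).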
    parallel_for(0, combs * d, internal::GRAIN_SIZE / (16 * m), [=](int64_t start, int64_t end) {
      scalar_t * res = res_start + start;
      const scalar_t * const res_end = res_start + end;
      int64_t l = start / combs;
      int64_t k = start % combs;
      int64_t i = k / r2;
      int64_t j = k % r2;
      i = i * m;
      j = j * m;

      while (res != res_end) {
        const scalar_t * self_i = t1_start + size1 * l + i;
        const scalar_t * self_j = t2_start + size2 * l + j;

        scalar_t agg = 0;
        for (const auto x : c10::irange(m)) {
          scalar_t a = *(self_i + x);
          scalar_t b = *(self_j + x);
          agg = F::red(agg, F::map(std::abs(a-b), p));
        }
        *res = F::finish(agg, p);

        res += 1;
        j += m;
        if (j == size2) {
          j = 0;
          i += m;
          if (i == size1) {
            i = 0;
            l += 1;
          }
        }
      }
    });
  }

  static void apply_cdist(Tensor& result, const Tensor& x1, const Tensor& x2, const scalar_t p) {
    if (p == 0.0) {
      run_parallel_cdist<zdist_calc<scalar_t>>(result, x1, x2, p);
    } else if (p == 1.0) {
      run_parallel_cdist<odist_calc<scalar_t>>(result, x1, x2, p);
    } else if (p == 2.0) {
      run_parallel_cdist<tdist_calc<scalar_t>>(result, x1, x2, p);
    } else if (std::isinf(p)) {
      run_parallel_cdist<idist_calc<scalar_t>>(result, x1, x2, p);
    } else {
      run_parallel_cdist<pdist_calc<scalar_t>>(result, x1, x2, p);
    }
  }

  // This does a backward pass down a Vec column of the input
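  // For every pair (i, j) the gradient contribution F::backward(x_i - x_j, ...)
  // is added to row i and subtracted from row j, Vec::size() columns at a time
  // (count may be smaller for the tail columns).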
  template <typename F>
  inline static void backward_down_column_pdist(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size()) {
    for (const scalar_t * const self_end = self_i + m * n; self_i != self_end - m; self_i += m, res_i += m) {

      const Vec self_vec_i = Vec::loadu(self_i, count);
      Vec res_vec_i = Vec::loadu(res_i, count);

      const scalar_t * self_j = self_i + m;
      scalar_t * res_j = res_i + m;
      for (; self_j != self_end; self_j += m, res_j += m, grad_k += gs, dist_k += 1) {
        const Vec self_vec_j = Vec::loadu(self_j, count);
        Vec res_vec_j = Vec::loadu(res_j, count);

        Vec res = F::backward(self_vec_i - self_vec_j, *grad_k, *dist_k, pvec);
        res_vec_i = res_vec_i + res;
        res_vec_j = res_vec_j - res;

        res_vec_j.store(res_j, count);
      }

      res_vec_i.store(res_i, count);
    }
  }

  template <typename F>
  static void run_backward_parallel_pdist(Tensor& result, const Tensor & grad, const Tensor & self, const scalar_t p, const Tensor& dist) {
    const int64_t n = self.size(0);
    const int64_t m = self.size(1);
    const int64_t gs = grad.stride(0);

    const scalar_t * const grad_start = grad.const_data_ptr<scalar_t>();
    const scalar_t * const dist_start = dist.const_data_ptr<scalar_t>();
    const scalar_t * const self_start = self.const_data_ptr<scalar_t>();
    scalar_t * const res_start = result.data_ptr<scalar_t>();

    // The only way to parallelize and avoid locking requires parallelizing
    // over the columns of the input, i.e. we compute the gradient for the
    // first section of each vector independently of the second section, etc.
    at::parallel_for(0, m / Vec::size(), internal::GRAIN_SIZE / (8 * n * n), [p, n, m, gs, grad_start, dist_start, self_start, res_start](int64_t l, int64_t end) {
      const Vec pvec(p);

      const scalar_t * self_l = self_start + l * Vec::size();
      scalar_t * res_l = res_start + l * Vec::size();

      for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; self_l += Vec::size(), res_l += Vec::size()) {
        backward_down_column_pdist<F>(self_l, res_l, grad_start, dist_start, pvec, n, m, gs);
      }
    });
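    // Columns left over when m is not a multiple of Vec::size() are handled
    // with a single partial-width pass (count = remainder).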
    const int64_t remainder = m % Vec::size();
    if (remainder) {
      backward_down_column_pdist<F>(self_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, Vec(p), n, m, gs, remainder);
    }
  }

  // Assumes self is nonempty, contiguous, and 2D and dist is also contiguous
  static void apply_backward_pdist(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) {
    result.fill_(0);
    if (p == 0.0) {
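      // The 0-norm has zero gradient everywhere, so the result stays filled
      // with zeros (see the comment at the top of this struct).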
    } else if (p == 1.0) {
      run_backward_parallel_pdist<odist_calc<Vec>>(result, grad, self, p, dist);
    } else if (p < 2.0) {
      run_backward_parallel_pdist<lttdist_calc>(result, grad, self, p, dist);
    } else if (p == 2.0) {
      run_backward_parallel_pdist<tdist_calc<Vec>>(result, grad, self, p, dist);
    } else if (std::isinf(p)) {
      run_backward_parallel_pdist<idist_calc<Vec>>(result, grad, self, p, dist);
    } else {
      run_backward_parallel_pdist<pdist_calc<Vec>>(result, grad, self, p, dist);
    }
  }

  static void apply_backward_cdist(Tensor& result, const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& dist) {
    result.fill_(0);
    if (p == 0.0) {
    } else if (p == 1.0) {
      run_backward_parallel_cdist<odist_calc<Vec>>(result, grad, x1, x2, p, dist);
    } else if (p < 2.0) {
      run_backward_parallel_cdist<lttdist_calc>(result, grad, x1, x2, p, dist);
    } else if (p == 2.0) {
      run_backward_parallel_cdist<tdist_calc<Vec>>(result, grad, x1, x2, p, dist);
    } else if (std::isinf(p)) {
      run_backward_parallel_cdist<idist_calc<Vec>>(result, grad, x1, x2, p, dist);
    } else {
      run_backward_parallel_cdist<pdist_calc<Vec>>(result, grad, x1, x2, p, dist);
    }
  }


  template <typename F>
  static void run_backward_parallel_cdist(Tensor& result, const Tensor & grad, const Tensor & t1, const Tensor & t2, const scalar_t p, const Tensor& dist) {
    const int64_t r1 = t1.size(-2);
    const int64_t r2 = t2.size(-2);
    const int64_t m = t1.size(-1);
    const int64_t d = result.size(0);
    const int64_t l1_size = r1 * m;
    const int64_t l2_size = r2 * m;
    // The current implementation supports only tensors that can be collapsed to 1D.
    // To avoid checking whether grad satisfies this assumption, .contiguous() is
    // called on grad before backward, so its stride is guaranteed to be 1.
    // Don't use grad.stride(-1): if the last dimension has size 1, the stride can be bogus.
    const int64_t gs = 1;

    const scalar_t * const grad_start = grad.const_data_ptr<scalar_t>();
    const scalar_t * const dist_start = dist.const_data_ptr<scalar_t>();
    const scalar_t * const t1_start = t1.const_data_ptr<scalar_t>();
    const scalar_t * const t2_start = t2.const_data_ptr<scalar_t>();
    scalar_t * const res_start = result.data_ptr<scalar_t>();

    at::parallel_for(0, m / Vec::size(), internal::GRAIN_SIZE / (16 * r1), [=](int64_t l, int64_t end) {
      const Vec pvec(p);

      const scalar_t * i = t1_start + l * Vec::size();
      const scalar_t * j = t2_start + l * Vec::size();
      scalar_t * res_l = res_start + l * Vec::size();

      for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; i += Vec::size(), j += Vec::size(), res_l += Vec::size()) {
        backward_down_column_cdist<F>(i, j, res_l, grad_start, dist_start, pvec, r1, r2, m, d, gs, l1_size, l2_size);
      }
    });
    const int64_t remainder = m % Vec::size();
    if (remainder) {
      backward_down_column_cdist<F>(t1_start + (m - remainder), t2_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, Vec(p), r1, r2, m, d, gs, l1_size, l2_size, remainder);
    }
  }
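
  // Accumulates the cdist backward contributions for one Vec-wide column of
  // x1: for each batch l and each row of t1, the gradients F::backward against
  // every row of t2 are summed into the corresponding row of the result.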
  template <typename F>
  inline static void backward_down_column_cdist(const scalar_t * t1, const scalar_t * t2, scalar_t * res, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t r1, int64_t r2, int64_t m, int64_t d, int64_t gs, int64_t l1_size, int64_t l2_size, int64_t count = Vec::size()) {
    const scalar_t * t1_end = t1 + l1_size;
    const scalar_t * t2_end = t2 + l2_size;

    for (const auto l C10_UNUSED : c10::irange(d)) {
      for (; t1 != t1_end; t1 += m, res += m) {
        const Vec vec_t1 = Vec::loadu(t1, count);
        Vec res_vec = Vec::loadu(res, count);

        for (const scalar_t * t2_curr = t2; t2_curr != t2_end; t2_curr += m, grad_k += gs, dist_k += 1) {
          const Vec vec_t2 = Vec::loadu(t2_curr, count);
          Vec res = F::backward(vec_t1 - vec_t2, *grad_k, *dist_k, pvec);
          res_vec = res_vec + res;
        }

        res_vec.store(res, count);
      }
      t1_end += l1_size;
      t2_end += l2_size;
      t2 += l2_size;
    }
  }

};

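// These kernels back the CPU dispatch stubs used by the pdist and cdist
// operators (forward and backward); they are registered at the bottom of
// this file.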
void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) {
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist", [&] {
    Dist<scalar_t>::apply_pdist(result, self, p);
  });
}

static void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) {
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_backward", [&] {
    Dist<scalar_t>::apply_backward_pdist(result, grad, self, p, dist);
  });
}

static void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, const double p) {
  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist", [&] {
    Dist<scalar_t>::apply_cdist(result, x1, x2, p);
  });
}

static void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& dist) {
  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist_backward", [&] {
    Dist<scalar_t>::apply_backward_cdist(result, grad, x1, x2, p, dist);
  });
}


}  // anonymous namespace

REGISTER_DISPATCH(pdist_forward_stub, &pdist_forward_kernel_impl);
REGISTER_DISPATCH(pdist_backward_stub, &pdist_backward_kernel_impl);
REGISTER_DISPATCH(cdist_stub, &cdist_kernel_impl);
REGISTER_DISPATCH(cdist_backward_stub, &cdist_backward_kernel_impl);

}  // namespace at::native