xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/distance_util.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
#include <cmath>

#include <executorch/runtime/kernel/kernel_includes.h>
12 
13 namespace torch {
14 namespace executor {
15 
16 bool check_pdist_args(const Tensor& in, double p, const Tensor& out);
17 
18 void get_pdist_out_target_size(
19     const Tensor& in,
20     Tensor::SizesType* out_sizes,
21     size_t* out_ndim);
22 
23 template <typename CTYPE, typename Norm>
pdist(const Tensor & in,Tensor & out,double p)24 void pdist(const Tensor& in, Tensor& out, double p) {
25   const CTYPE* in_data = in.const_data_ptr<CTYPE>();
26   CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
27 
28   size_t n = in.size(0);
29   size_t m = in.size(1);
30 
31   size_t out_ix = 0;
32   for (size_t i = 0; i < n; ++i) {
33     for (size_t j = i + 1; j < n; ++j) {
34       const CTYPE* row_i = in_data + i * m;
35       const CTYPE* row_j = in_data + j * m;
36       CTYPE agg = 0;
37       for (size_t k = 0; k < m; ++k) {
38         CTYPE diff = std::abs(row_i[k] - row_j[k]);
39         agg = Norm::reduce(agg, Norm::map(diff, p));
40       }
41       out_data[out_ix++] = Norm::finish(agg, p);
42     }
43   }
44 }
45 
/**
 * Norm policy for p == 0: counts the coordinates in which two rows differ.
 *
 * `map` turns each absolute coordinate difference into 0 (equal) or 1
 * (different), `reduce` sums the indicators and `finish` returns the sum
 * unchanged. The exponent argument is ignored.
 *
 * A NaN difference is passed through rather than counted as 1, so NaN
 * inputs propagate into the result. The L1, L2 and Lp policies already do
 * this through their arithmetic.
 */
template <typename CTYPE>
struct L0 {
  static inline CTYPE map(const CTYPE& diff, const CTYPE&) {
    // NaN compares unequal to 0, so without this check it would be counted
    // as an ordinary differing coordinate and vanish from the result.
    if (std::isnan(diff)) {
      return diff;
    }
    return diff == 0 ? 0 : 1;
  }
  static inline CTYPE reduce(const CTYPE& agg, const CTYPE& up) {
    return agg + up;
  }
  static inline CTYPE finish(const CTYPE& agg, const CTYPE&) {
    return agg;
  }
};
58 
/**
 * Norm policy for p == 1 (Manhattan distance): the plain sum of the absolute
 * coordinate differences. The exponent argument is ignored.
 */
template <typename CTYPE>
struct L1 {
  /// Each absolute difference contributes unchanged.
  static inline CTYPE map(const CTYPE& abs_diff, const CTYPE& /*p*/) {
    return abs_diff;
  }
  /// Contributions are accumulated by addition.
  static inline CTYPE reduce(const CTYPE& acc, const CTYPE& term) {
    const CTYPE sum = acc + term;
    return sum;
  }
  /// No post-processing: the accumulated sum is the distance.
  static inline CTYPE finish(const CTYPE& acc, const CTYPE& /*p*/) {
    return acc;
  }
};
71 
/**
 * Norm policy for p == 2 (Euclidean distance): the square root of the sum of
 * squared coordinate differences. The exponent argument is ignored.
 */
template <typename CTYPE>
struct L2 {
  /// Squares each absolute difference.
  static inline CTYPE map(const CTYPE& abs_diff, const CTYPE& /*p*/) {
    const CTYPE squared = abs_diff * abs_diff;
    return squared;
  }
  /// Sums the squared terms.
  static inline CTYPE reduce(const CTYPE& acc, const CTYPE& term) {
    return acc + term;
  }
  /// Takes the square root of the accumulated sum of squares.
  static inline CTYPE finish(const CTYPE& sum_of_squares, const CTYPE& /*p*/) {
    return std::sqrt(sum_of_squares);
  }
};
84 
/**
 * Norm policy for a general exponent p: (sum |d_k|^p)^(1/p).
 */
template <typename CTYPE>
struct Lp {
  /// Raises each absolute difference to the power `p`.
  static inline CTYPE map(const CTYPE& abs_diff, const CTYPE& p) {
    return std::pow(abs_diff, p);
  }
  /// Sums the powered terms.
  static inline CTYPE reduce(const CTYPE& acc, const CTYPE& term) {
    return acc + term;
  }
  /// Takes the p-th root of the accumulated sum.
  static inline CTYPE finish(const CTYPE& acc, const CTYPE& p) {
    // `1.0 / p` is evaluated in at least double precision (a float `p` is
    // promoted); `auto` keeps exactly that type.
    const auto inv_p = 1.0 / p;
    return std::pow(acc, inv_p);
  }
};
97 
/**
 * Norm policy for p == infinity (Chebyshev distance): the largest absolute
 * coordinate difference. The exponent argument is ignored.
 *
 * NaN differences propagate into the result, matching the L1, L2 and Lp
 * policies, whose arithmetic already propagates NaN.
 */
template <typename CTYPE>
struct Linf {
  static inline CTYPE map(const CTYPE& diff, const CTYPE&) {
    return diff;
  }
  static inline CTYPE reduce(const CTYPE& agg, const CTYPE& up) {
    // Written out instead of std::max(agg, up). That evaluates
    // `(agg < up) ? up : agg`, and because every comparison with NaN is
    // false it returns `agg` for a NaN `up`, silently dropping the NaN.
    // Here a NaN `up` is returned. Once `agg` is NaN, `up > agg` stays
    // false, so the NaN is kept for the rest of the fold.
    return (std::isnan(up) || up > agg) ? up : agg;
  }
  static inline CTYPE finish(const CTYPE& agg, const CTYPE&) {
    return agg;
  }
};
110 
111 template <typename CTYPE>
pdist(const Tensor & in,Tensor & out,double p)112 void pdist(const Tensor& in, Tensor& out, double p) {
113   if (p == 0.0) {
114     pdist<CTYPE, L0<CTYPE>>(in, out, p);
115   } else if (p == 1.0) {
116     pdist<CTYPE, L1<CTYPE>>(in, out, p);
117   } else if (p == 2.0) {
118     pdist<CTYPE, L2<CTYPE>>(in, out, p);
119   } else if (p == INFINITY) {
120     pdist<CTYPE, Linf<CTYPE>>(in, out, p);
121   } else {
122     pdist<CTYPE, Lp<CTYPE>>(in, out, p);
123   }
124 }
125 
126 bool check_cdist_args(
127     const Tensor& x1,
128     const Tensor& x2,
129     double p,
130     optional<int64_t> compute_mode,
131     const Tensor& out);
132 
133 } // namespace executor
134 } // namespace torch
135