1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
11 #include <executorch/runtime/kernel/kernel_includes.h>
12
13 namespace torch {
14 namespace executor {
15
/**
 * Argument check for the pdist kernel: takes the input tensor, the norm
 * order `p`, and the output tensor, and reports a boolean verdict.
 *
 * NOTE(review): only the declaration lives in this header. Presumably `true`
 * means the arguments are acceptable; confirm the exact conditions against
 * the definition.
 */
bool check_pdist_args(const Tensor& in, double p, const Tensor& out);
17
/**
 * Reports results through the two out-pointers `out_sizes` and `out_ndim`
 * instead of a return value.
 *
 * NOTE(review): only the declaration lives in this header. Judging by the
 * name, it presumably writes the expected pdist output shape for `in` into
 * `out_sizes` and its rank into `out_ndim`; confirm against the definition,
 * including how large the `out_sizes` buffer must be.
 */
void get_pdist_out_target_size(
    const Tensor& in,
    Tensor::SizesType* out_sizes,
    size_t* out_ndim);
22
23 template <typename CTYPE, typename Norm>
pdist(const Tensor & in,Tensor & out,double p)24 void pdist(const Tensor& in, Tensor& out, double p) {
25 const CTYPE* in_data = in.const_data_ptr<CTYPE>();
26 CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
27
28 size_t n = in.size(0);
29 size_t m = in.size(1);
30
31 size_t out_ix = 0;
32 for (size_t i = 0; i < n; ++i) {
33 for (size_t j = i + 1; j < n; ++j) {
34 const CTYPE* row_i = in_data + i * m;
35 const CTYPE* row_j = in_data + j * m;
36 CTYPE agg = 0;
37 for (size_t k = 0; k < m; ++k) {
38 CTYPE diff = std::abs(row_i[k] - row_j[k]);
39 agg = Norm::reduce(agg, Norm::map(diff, p));
40 }
41 out_data[out_ix++] = Norm::finish(agg, p);
42 }
43 }
44 }
45
/**
 * "Zero norm" policy for pdist: the distance is the number of element
 * positions whose difference is non-zero. The norm order argument is ignored
 * by every hook.
 */
template <typename CTYPE>
struct L0 {
  /** Maps a difference to 0 when it equals zero and to 1 otherwise. */
  static inline CTYPE map(const CTYPE& diff, const CTYPE& /*p*/) {
    if (diff == 0) {
      return 0;
    }
    return 1;
  }

  /** Adds the mapped value `up` to the running count `agg`. */
  static inline CTYPE reduce(const CTYPE& agg, const CTYPE& up) {
    const CTYPE total = agg + up;
    return total;
  }

  /** The accumulated count is already the final result. */
  static inline CTYPE finish(const CTYPE& agg, const CTYPE& /*p*/) {
    return agg;
  }
};
58
/**
 * L1 (sum of absolute differences) policy for pdist. The norm order argument
 * is ignored by every hook.
 */
template <typename CTYPE>
struct L1 {
  /** The difference is used as-is. */
  static inline CTYPE map(const CTYPE& diff, const CTYPE& /*p*/) {
    return diff;
  }

  /** Adds the mapped value `up` to the running sum `agg`. */
  static inline CTYPE reduce(const CTYPE& agg, const CTYPE& up) {
    const CTYPE total = agg + up;
    return total;
  }

  /** The accumulated sum is already the final result. */
  static inline CTYPE finish(const CTYPE& agg, const CTYPE& /*p*/) {
    return agg;
  }
};
71
/**
 * L2 (Euclidean) policy for pdist: squares each difference, sums the squares,
 * and takes the square root at the end. The norm order argument is ignored by
 * every hook.
 */
template <typename CTYPE>
struct L2 {
  /** Squares the difference. */
  static inline CTYPE map(const CTYPE& diff, const CTYPE& /*p*/) {
    const CTYPE squared = diff * diff;
    return squared;
  }

  /** Adds the squared value `up` to the running sum `agg`. */
  static inline CTYPE reduce(const CTYPE& agg, const CTYPE& up) {
    const CTYPE total = agg + up;
    return total;
  }

  /** Square root of the accumulated sum of squares. */
  static inline CTYPE finish(const CTYPE& agg, const CTYPE& /*p*/) {
    return std::sqrt(agg);
  }
};
84
/**
 * General p-norm policy for pdist: raises each difference to the power `p`,
 * sums the results, and raises the sum to the power `1 / p`.
 */
template <typename CTYPE>
struct Lp {
  /** Raises the difference to the power `p`. */
  static inline CTYPE map(const CTYPE& diff, const CTYPE& p) {
    return std::pow(diff, p);
  }

  /** Adds the mapped value `up` to the running sum `agg`. */
  static inline CTYPE reduce(const CTYPE& agg, const CTYPE& up) {
    const CTYPE total = agg + up;
    return total;
  }

  /** Raises the accumulated sum to the power `1 / p`. */
  static inline CTYPE finish(const CTYPE& agg, const CTYPE& p) {
    // The reciprocal keeps whatever type `1.0 / p` has, so the same std::pow
    // overload is selected as when the expression is written inline.
    const auto inv_p = 1.0 / p;
    return std::pow(agg, inv_p);
  }
};
97
/**
 * Infinity-norm (maximum absolute difference) policy for pdist. The norm
 * order argument is ignored by every hook.
 */
template <typename CTYPE>
struct Linf {
  /** The difference is used as-is. */
  static inline CTYPE map(const CTYPE& diff, const CTYPE& /*p*/) {
    return diff;
  }

  /**
   * Returns the larger of the running maximum `agg` and the new value `up`,
   * propagating NaN when either operand is NaN.
   *
   * A plain `std::max(agg, up)` returns `agg` whenever `up` is NaN (every
   * ordered comparison with NaN is false), which would silently discard a NaN
   * element difference, whereas the sum-based norms propagate it.
   */
  static inline CTYPE reduce(const CTYPE& agg, const CTYPE& up) {
    // `up != up` holds only for NaN. When `agg` is NaN both tests are false
    // and `agg` is returned, so NaN survives from either side. On ties `agg`
    // is returned, as std::max(agg, up) would.
    return (up > agg || up != up) ? up : agg;
  }

  /** The accumulated maximum is already the final result. */
  static inline CTYPE finish(const CTYPE& agg, const CTYPE& /*p*/) {
    return agg;
  }
};
110
111 template <typename CTYPE>
pdist(const Tensor & in,Tensor & out,double p)112 void pdist(const Tensor& in, Tensor& out, double p) {
113 if (p == 0.0) {
114 pdist<CTYPE, L0<CTYPE>>(in, out, p);
115 } else if (p == 1.0) {
116 pdist<CTYPE, L1<CTYPE>>(in, out, p);
117 } else if (p == 2.0) {
118 pdist<CTYPE, L2<CTYPE>>(in, out, p);
119 } else if (p == INFINITY) {
120 pdist<CTYPE, Linf<CTYPE>>(in, out, p);
121 } else {
122 pdist<CTYPE, Lp<CTYPE>>(in, out, p);
123 }
124 }
125
/**
 * Argument check for the cdist kernel: takes the two input tensors, the norm
 * order `p`, an optional `compute_mode`, and the output tensor, and reports a
 * boolean verdict.
 *
 * NOTE(review): only the declaration lives in this header. Presumably `true`
 * means the arguments are acceptable; confirm the exact conditions and the
 * accepted `compute_mode` values against the definition.
 */
bool check_cdist_args(
    const Tensor& x1,
    const Tensor& x2,
    double p,
    optional<int64_t> compute_mode,
    const Tensor& out);
132
133 } // namespace executor
134 } // namespace torch
135