1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/util/distance_util.h>
10
11 namespace torch {
12 namespace executor {
13
check_pdist_args(const Tensor & in,double p,const Tensor & out)14 bool check_pdist_args(const Tensor& in, double p, const Tensor& out) {
15 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
16 ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2));
17 ET_LOG_MSG_AND_RETURN_IF_FALSE(
18 p >= 0, "pdist only supports non-negative p values");
19 return true;
20 }
21
get_pdist_out_target_size(const Tensor & in,Tensor::SizesType * out_sizes,size_t * out_ndim)22 void get_pdist_out_target_size(
23 const Tensor& in,
24 Tensor::SizesType* out_sizes,
25 size_t* out_ndim) {
26 *out_ndim = 1;
27 size_t n = in.size(0);
28 out_sizes[0] = n * (n - 1) / 2;
29 }
30
check_cdist_args(const Tensor & x1,const Tensor & x2,double p,optional<int64_t> compute_mode,const Tensor & out)31 bool check_cdist_args(
32 const Tensor& x1,
33 const Tensor& x2,
34 double p,
35 optional<int64_t> compute_mode,
36 const Tensor& out) {
37 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(x1, x2));
38 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(x1, out));
39 ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_greater_or_equal_to(x1, 2));
40 ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_greater_or_equal_to(x2, 2));
41 ET_LOG_AND_RETURN_IF_FALSE(
42 tensors_have_same_size_at_dims(x1, x1.dim() - 1, x2, x2.dim() - 1));
43 ET_LOG_MSG_AND_RETURN_IF_FALSE(
44 p >= 0, "cdist only supports non-negative p values");
45 if (compute_mode.has_value()) {
46 int64_t mode = compute_mode.value();
47 ET_LOG_MSG_AND_RETURN_IF_FALSE(
48 mode >= 0 && mode <= 2,
49 "possible modes: 0, 1, 2, but was: %" PRId64,
50 mode);
51 }
52 return true;
53 }
54
55 } // namespace executor
56 } // namespace torch
57