xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/distance_util.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/util/distance_util.h>
10 
11 namespace torch {
12 namespace executor {
13 
check_pdist_args(const Tensor & in,double p,const Tensor & out)14 bool check_pdist_args(const Tensor& in, double p, const Tensor& out) {
15   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
16   ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2));
17   ET_LOG_MSG_AND_RETURN_IF_FALSE(
18       p >= 0, "pdist only supports non-negative p values");
19   return true;
20 }
21 
get_pdist_out_target_size(const Tensor & in,Tensor::SizesType * out_sizes,size_t * out_ndim)22 void get_pdist_out_target_size(
23     const Tensor& in,
24     Tensor::SizesType* out_sizes,
25     size_t* out_ndim) {
26   *out_ndim = 1;
27   size_t n = in.size(0);
28   out_sizes[0] = n * (n - 1) / 2;
29 }
30 
check_cdist_args(const Tensor & x1,const Tensor & x2,double p,optional<int64_t> compute_mode,const Tensor & out)31 bool check_cdist_args(
32     const Tensor& x1,
33     const Tensor& x2,
34     double p,
35     optional<int64_t> compute_mode,
36     const Tensor& out) {
37   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(x1, x2));
38   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(x1, out));
39   ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_greater_or_equal_to(x1, 2));
40   ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_greater_or_equal_to(x2, 2));
41   ET_LOG_AND_RETURN_IF_FALSE(
42       tensors_have_same_size_at_dims(x1, x1.dim() - 1, x2, x2.dim() - 1));
43   ET_LOG_MSG_AND_RETURN_IF_FALSE(
44       p >= 0, "cdist only supports non-negative p values");
45   if (compute_mode.has_value()) {
46     int64_t mode = compute_mode.value();
47     ET_LOG_MSG_AND_RETURN_IF_FALSE(
48         mode >= 0 && mode <= 2,
49         "possible modes: 0, 1, 2, but was: %" PRId64,
50         mode);
51   }
52   return true;
53 }
54 
55 } // namespace executor
56 } // namespace torch
57