xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_min.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cmath>
10 #include <tuple>
11 
12 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14 #include <executorch/runtime/platform/assert.h>
15 
16 namespace torch {
17 namespace executor {
18 namespace native {
namespace {

/**
 * Returns the greatest value of CTYPE suitable for seeding a running minimum.
 *
 * When the type models infinity (per std::numeric_limits<CTYPE>::has_infinity)
 * this is +infinity; otherwise it is std::numeric_limits<CTYPE>::max().
 */
template <typename CTYPE>
constexpr CTYPE upper_bound() {
  using traits = std::numeric_limits<CTYPE>;
  if (traits::has_infinity) {
    return traits::infinity();
  }
  return traits::max();
}

} // namespace
28 
29 using ScalarType = exec_aten::ScalarType;
30 using SizesType = exec_aten::SizesType;
31 using Tensor = exec_aten::Tensor;
32 
min_out(KernelRuntimeContext & ctx,const Tensor & in,int64_t dim,bool keepdim,Tensor & min,Tensor & min_indices)33 std::tuple<Tensor&, Tensor&> min_out(
34     KernelRuntimeContext& ctx,
35     const Tensor& in,
36     int64_t dim,
37     bool keepdim,
38     Tensor& min,
39     Tensor& min_indices) {
40   (void)ctx;
41 
42   ET_KERNEL_CHECK(
43       ctx,
44       check_min_max_args(in, dim, keepdim, min, min_indices),
45       InvalidArgument,
46       (std::tuple<Tensor&, Tensor&>({min, min_indices})));
47 
48   ET_KERNEL_CHECK(
49       ctx,
50       resize_reduction_out(in, dim, keepdim, min) == Error::Ok,
51       InvalidArgument,
52       (std::tuple<Tensor&, Tensor&>({min, min_indices})));
53 
54   ET_KERNEL_CHECK(
55       ctx,
56       resize_tensor(min_indices, min.sizes()) == Error::Ok,
57       InvalidArgument,
58       (std::tuple<Tensor&, Tensor&>({min, min_indices})));
59 
60   ET_KERNEL_CHECK(
61       ctx,
62       tensors_have_same_dim_order(in, min),
63       InvalidArgument,
64       (std::tuple<Tensor&, Tensor&>({min, min_indices})));
65 
66   ET_KERNEL_CHECK(
67       ctx,
68       tensor_is_default_dim_order(min_indices),
69       InvalidArgument,
70       (std::tuple<Tensor&, Tensor&>({min, min_indices})));
71 
72   ET_KERNEL_CHECK(
73       ctx,
74       tensor_is_default_dim_order(in),
75       InvalidArgument,
76       (std::tuple<Tensor&, Tensor&>({min, min_indices})));
77 
78   dim = dim < 0 ? dim + in.dim() : dim;
79 
80   ET_SWITCH_REAL_TYPES_AND(
81       Bool, in.scalar_type(), ctx, "min.dim_min", CTYPE, [&]() {
82         CTYPE* min_data = min.mutable_data_ptr<CTYPE>();
83         long* min_indices_data = min_indices.mutable_data_ptr<long>();
84 
85         for (size_t out_ix = 0; out_ix < min.numel(); ++out_ix) {
86           std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
87               [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
88                 if (!std::isnan(acc_val) && (std::isnan(v) || v < acc_val)) {
89                   acc_val = v;
90                   acc_ix = ix;
91                 }
92                 return std::tuple<CTYPE, long>{acc_val, acc_ix};
93               },
94               in,
95               dim,
96               out_ix);
97           min_data[out_ix] = std::get<0>(acc);
98           min_indices_data[out_ix] = std::get<1>(acc);
99         }
100       });
101 
102   return {min, min_indices};
103 }
104 
105 Tensor&
min_unary_out(KernelRuntimeContext & ctx,const Tensor & in,Tensor & out)106 min_unary_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
107   (void)ctx;
108 
109   ET_KERNEL_CHECK(
110       ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);
111 
112   ET_KERNEL_CHECK(
113       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
114 
115   ScalarType in_type = in.scalar_type();
116   ScalarType out_type = out.scalar_type();
117 
118   ET_KERNEL_CHECK(ctx, canCast(in_type, out_type), InvalidArgument, out);
119 
120   constexpr auto name = "min.unary_out";
121 
122   ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
123     ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
124       const auto data_in = in.const_data_ptr<CTYPE_IN>();
125       auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
126       data_out[0] = upper_bound<CTYPE_OUT>();
127       for (auto i = 0; i < in.numel(); ++i) {
128         CTYPE_OUT val = static_cast<CTYPE_OUT>(data_in[i]);
129         if (std::isnan(val)) {
130           data_out[0] = val;
131           break;
132         }
133         if (val < data_out[0]) {
134           data_out[0] = val;
135         }
136       }
137     });
138   });
139 
140   return out;
141 }
142 
143 } // namespace native
144 } // namespace executor
145 } // namespace torch
146