1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cmath>
10 #include <tuple>
11
12 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14 #include <executorch/runtime/platform/assert.h>
15
16 namespace torch {
17 namespace executor {
18 namespace native {
19 namespace {
20
// Identity element for a running-minimum reduction over CTYPE: positive
// infinity when the type can represent one, otherwise the largest finite
// value. Seeding an accumulator with this makes every real element compare
// smaller on the first step.
template <typename CTYPE>
constexpr CTYPE upper_bound() {
  using limits = std::numeric_limits<CTYPE>;
  if constexpr (limits::has_infinity) {
    return limits::infinity();
  } else {
    return limits::max();
  }
}
26
27 } // namespace
28
29 using ScalarType = exec_aten::ScalarType;
30 using SizesType = exec_aten::SizesType;
31 using Tensor = exec_aten::Tensor;
32
/// Out-variant kernel for `min.dim_min`: reduces `in` along dimension `dim`,
/// writing the per-slice minimum values into `min` and the index (within
/// `dim`) of each minimum into `min_indices`. If `keepdim` is true the
/// reduced dimension is retained with size 1. NaN is "sticky": once a NaN is
/// seen in a slice it is reported as that slice's minimum.
///
/// @param ctx Kernel runtime context used by the ET_KERNEL_CHECK macros.
/// @param in Input tensor (real types plus Bool).
/// @param dim Dimension to reduce; may be negative (normalized below).
/// @param keepdim Whether the reduced dimension is kept with size 1.
/// @param min Out tensor receiving the minimum values.
/// @param min_indices Out tensor receiving the argmin indices.
/// @return The (min, min_indices) pair, returned by reference for chaining.
std::tuple<Tensor&, Tensor&> min_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    int64_t dim,
    bool keepdim,
    Tensor& min,
    Tensor& min_indices) {
  (void)ctx;

  // Each ET_KERNEL_CHECK below flags InvalidArgument on failure and
  // early-returns the (unmodified) out-tensor tuple.
  ET_KERNEL_CHECK(
      ctx,
      check_min_max_args(in, dim, keepdim, min, min_indices),
      InvalidArgument,
      (std::tuple<Tensor&, Tensor&>({min, min_indices})));

  // Resize `min` to the reduced shape implied by (in, dim, keepdim).
  ET_KERNEL_CHECK(
      ctx,
      resize_reduction_out(in, dim, keepdim, min) == Error::Ok,
      InvalidArgument,
      (std::tuple<Tensor&, Tensor&>({min, min_indices})));

  // `min_indices` must have the same shape as `min`.
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(min_indices, min.sizes()) == Error::Ok,
      InvalidArgument,
      (std::tuple<Tensor&, Tensor&>({min, min_indices})));

  ET_KERNEL_CHECK(
      ctx,
      tensors_have_same_dim_order(in, min),
      InvalidArgument,
      (std::tuple<Tensor&, Tensor&>({min, min_indices})));

  ET_KERNEL_CHECK(
      ctx,
      tensor_is_default_dim_order(min_indices),
      InvalidArgument,
      (std::tuple<Tensor&, Tensor&>({min, min_indices})));

  ET_KERNEL_CHECK(
      ctx,
      tensor_is_default_dim_order(in),
      InvalidArgument,
      (std::tuple<Tensor&, Tensor&>({min, min_indices})));

  // Normalize a negative dim to its non-negative equivalent.
  dim = dim < 0 ? dim + in.dim() : dim;

  // Dispatch on the input dtype (real types plus Bool); CTYPE is the
  // element type of `in` inside the lambda.
  ET_SWITCH_REAL_TYPES_AND(
      Bool, in.scalar_type(), ctx, "min.dim_min", CTYPE, [&]() {
        CTYPE* min_data = min.mutable_data_ptr<CTYPE>();
        long* min_indices_data = min_indices.mutable_data_ptr<long>();

        // One reduction per output element: fold over the `dim` slice that
        // maps to out_ix, tracking the running (value, index) pair.
        for (size_t out_ix = 0; out_ix < min.numel(); ++out_ix) {
          std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
              [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
                // A NaN accumulator is never displaced; otherwise a NaN
                // candidate or a strictly smaller value takes over. Using
                // strict `<` keeps the earliest index on ties — assuming
                // reduce_over_dim visits elements in order (TODO confirm).
                if (!std::isnan(acc_val) && (std::isnan(v) || v < acc_val)) {
                  acc_val = v;
                  acc_ix = ix;
                }
                return std::tuple<CTYPE, long>{acc_val, acc_ix};
              },
              in,
              dim,
              out_ix);
          min_data[out_ix] = std::get<0>(acc);
          min_indices_data[out_ix] = std::get<1>(acc);
        }
      });

  return {min, min_indices};
}
104
105 Tensor&
min_unary_out(KernelRuntimeContext & ctx,const Tensor & in,Tensor & out)106 min_unary_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
107 (void)ctx;
108
109 ET_KERNEL_CHECK(
110 ctx, resize_tensor(out, {}) == Error::Ok, InvalidArgument, out);
111
112 ET_KERNEL_CHECK(
113 ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
114
115 ScalarType in_type = in.scalar_type();
116 ScalarType out_type = out.scalar_type();
117
118 ET_KERNEL_CHECK(ctx, canCast(in_type, out_type), InvalidArgument, out);
119
120 constexpr auto name = "min.unary_out";
121
122 ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
123 ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
124 const auto data_in = in.const_data_ptr<CTYPE_IN>();
125 auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
126 data_out[0] = upper_bound<CTYPE_OUT>();
127 for (auto i = 0; i < in.numel(); ++i) {
128 CTYPE_OUT val = static_cast<CTYPE_OUT>(data_in[i]);
129 if (std::isnan(val)) {
130 data_out[0] = val;
131 break;
132 }
133 if (val < data_out[0]) {
134 data_out[0] = val;
135 }
136 }
137 });
138 });
139
140 return out;
141 }
142
143 } // namespace native
144 } // namespace executor
145 } // namespace torch
146