1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cmath>
10 #include <tuple>
11
12 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14 #include <executorch/runtime/platform/assert.h>
15
16 namespace torch {
17 namespace executor {
18 namespace native {
19
20 using exec_aten::optional;
21 using exec_aten::Tensor;
22
argmax_out(KernelRuntimeContext & ctx,const Tensor & in,optional<int64_t> dim,bool keepdim,Tensor & out)23 Tensor& argmax_out(
24 KernelRuntimeContext& ctx,
25 const Tensor& in,
26 optional<int64_t> dim,
27 bool keepdim,
28 Tensor& out) {
29 (void)ctx;
30
31 ET_KERNEL_CHECK(
32 ctx,
33 check_argmin_argmax_args(in, dim, keepdim, out),
34 InvalidArgument,
35 out);
36
37 ET_KERNEL_CHECK(
38 ctx,
39 resize_reduction_out(in, dim, keepdim, out) == Error::Ok,
40 InvalidArgument,
41 out);
42
43 ET_KERNEL_CHECK(
44 ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
45
46 ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "argmax.out", CTYPE, [&] {
47 long* out_data = out.mutable_data_ptr<long>();
48
49 for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
50 std::tuple<CTYPE, long> acc = reduce_over_dim<CTYPE>(
51 [](CTYPE v, long ix, CTYPE acc_val, long acc_ix) {
52 if (!std::isnan(acc_val) && (std::isnan(v) || v > acc_val)) {
53 acc_val = v;
54 acc_ix = ix;
55 }
56 return std::tuple<CTYPE, long>{acc_val, acc_ix};
57 },
58 in,
59 dim,
60 out_ix);
61 out_data[out_ix] = std::get<1>(acc);
62 }
63 });
64
65 return out;
66 }
67
68 } // namespace native
69 } // namespace executor
70 } // namespace torch
71