xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_logit.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cmath>
10 
11 #include <executorch/kernels/portable/cpu/util/functional_util.h>
12 #include <executorch/runtime/kernel/kernel_includes.h>
13 
14 namespace torch {
15 namespace executor {
16 namespace native {
17 
18 using exec_aten::Tensor;
19 
logit_out(KernelRuntimeContext & ctx,const Tensor & in,exec_aten::optional<double> eps,Tensor & out)20 Tensor& logit_out(
21     KernelRuntimeContext& ctx,
22     const Tensor& in,
23     exec_aten::optional<double> eps,
24     Tensor& out) {
25   (void)ctx;
26 
27   // Resize for dynamic shape
28   ET_KERNEL_CHECK(
29       ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
30 
31   ET_KERNEL_CHECK(
32       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
33 
34   ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out);
35 
36   ScalarType in_type = in.scalar_type();
37   ScalarType out_type = out.scalar_type();
38   ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "logit.out", CTYPE_IN, [&] {
39     ET_SWITCH_FLOAT_TYPES(out_type, ctx, "logit.out", CTYPE_OUT, [&] {
40       apply_unary_map_fn(
41           [eps](const CTYPE_IN val_in) {
42             CTYPE_OUT xi = static_cast<CTYPE_OUT>(val_in);
43             if (eps.has_value()) {
44               if (xi < eps.value()) {
45                 xi = eps.value();
46               } else if (xi > 1 - eps.value()) {
47                 xi = 1 - eps.value();
48               }
49             }
50             return static_cast<CTYPE_OUT>(
51                 log(xi / (static_cast<CTYPE_OUT>(1.0) - xi)));
52           },
53           in.const_data_ptr<CTYPE_IN>(),
54           out.mutable_data_ptr<CTYPE_OUT>(),
55           in.numel());
56     });
57   });
58   return out;
59 }
60 
61 } // namespace native
62 } // namespace executor
63 } // namespace torch
64