xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_sigmoid.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cmath>
10 
11 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
12 #include <executorch/kernels/portable/cpu/util/functional_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14 
15 namespace torch {
16 namespace executor {
17 namespace native {
18 
19 using Tensor = exec_aten::Tensor;
20 
sigmoid_out(KernelRuntimeContext & ctx,const Tensor & in,Tensor & out)21 Tensor& sigmoid_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
22   (void)ctx;
23 
24   ET_KERNEL_CHECK(
25       ctx, in.scalar_type() != ScalarType::Bool, InvalidArgument, out);
26   ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out);
27 
28   ET_KERNEL_CHECK(
29       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
30 
31   // Resize for dynamic shape
32   ET_KERNEL_CHECK_MSG(
33       ctx,
34       resize_tensor(out, in.sizes()) == Error::Ok,
35       InvalidArgument,
36       out,
37       "Failed to resize output tensor.");
38 
39   ScalarType compute_type =
40       executorch::runtime::isFloatingType(in.scalar_type()) ? in.scalar_type()
41                                                             : ScalarType::Float;
42   compute_type = utils::get_compute_type(compute_type);
43 
44   // @lint-ignore CLANGTIDY facebook-hte-CArray
45   static constexpr const char op_name[] = "sigmoid.out";
46 
47   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
48     utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
49         [](const CTYPE_COMPUTE val_in) {
50           CTYPE_COMPUTE out_val = static_cast<CTYPE_COMPUTE>(1.0) /
51               (static_cast<CTYPE_COMPUTE>(1.0) + exp(-val_in));
52           return out_val;
53         },
54         ctx,
55         in,
56         utils::SupportedTensorDtypes::REALHBBF16,
57         out,
58         utils::SupportedTensorDtypes::FLOATHBF16);
59   });
60 
61   return out;
62 }
63 
64 } // namespace native
65 } // namespace executor
66 } // namespace torch
67