xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_atan2.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11 #include <cmath>
12 
13 namespace torch {
14 namespace executor {
15 namespace native {
16 
17 namespace {
18 
get_common_type(ScalarType a_type,ScalarType b_type)19 ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
20   if (isFloatingType(a_type) && isFloatingType(b_type)) {
21     return promoteTypes(a_type, b_type);
22   } else if (isFloatingType(a_type)) {
23     return a_type;
24   } else if (isFloatingType(b_type)) {
25     return b_type;
26   }
27   return ScalarType::Float;
28 }
29 
30 } // namespace
31 
atan2_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)32 Tensor& atan2_out(
33     KernelRuntimeContext& ctx,
34     const Tensor& a,
35     const Tensor& b,
36     Tensor& out) {
37   // Common Dtype
38   ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());
39 
40   // Check Dim Order
41   ET_KERNEL_CHECK(
42       ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
43 
44   // Resize
45   ET_KERNEL_CHECK(
46       ctx,
47       resize_to_broadcast_target_size(a, b, out) == Error::Ok,
48       InvalidArgument,
49       out);
50 
51   // Compute Dtype
52   ScalarType compute_type = utils::get_compute_type(common_type);
53 
54   // @lint-ignore CLANGTIDY facebook-hte-CArray
55   static constexpr const char op_name[] = "atan2.out";
56 
57   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
58     utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
59         [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
60           return std::atan2(val_a, val_b);
61         },
62         ctx,
63         a,
64         utils::SupportedTensorDtypes::REALHBBF16,
65         b,
66         utils::SupportedTensorDtypes::REALHBBF16,
67         out,
68         utils::SupportedTensorDtypes::FLOATHBF16);
69   });
70 
71   return out;
72 }
73 
74 } // namespace native
75 } // namespace executor
76 } // namespace torch
77