1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11 #include <cmath>
12
13 namespace torch {
14 namespace executor {
15 namespace native {
16
17 namespace {
18
get_common_type(ScalarType a_type,ScalarType b_type)19 ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
20 if (isFloatingType(a_type) && isFloatingType(b_type)) {
21 return promoteTypes(a_type, b_type);
22 } else if (isFloatingType(a_type)) {
23 return a_type;
24 } else if (isFloatingType(b_type)) {
25 return b_type;
26 }
27 return ScalarType::Float;
28 }
29
30 } // namespace
31
atan2_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)32 Tensor& atan2_out(
33 KernelRuntimeContext& ctx,
34 const Tensor& a,
35 const Tensor& b,
36 Tensor& out) {
37 // Common Dtype
38 ScalarType common_type = get_common_type(a.scalar_type(), b.scalar_type());
39
40 // Check Dim Order
41 ET_KERNEL_CHECK(
42 ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
43
44 // Resize
45 ET_KERNEL_CHECK(
46 ctx,
47 resize_to_broadcast_target_size(a, b, out) == Error::Ok,
48 InvalidArgument,
49 out);
50
51 // Compute Dtype
52 ScalarType compute_type = utils::get_compute_type(common_type);
53
54 // @lint-ignore CLANGTIDY facebook-hte-CArray
55 static constexpr const char op_name[] = "atan2.out";
56
57 ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
58 utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
59 [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
60 return std::atan2(val_a, val_b);
61 },
62 ctx,
63 a,
64 utils::SupportedTensorDtypes::REALHBBF16,
65 b,
66 utils::SupportedTensorDtypes::REALHBBF16,
67 out,
68 utils::SupportedTensorDtypes::FLOATHBF16);
69 });
70
71 return out;
72 }
73
74 } // namespace native
75 } // namespace executor
76 } // namespace torch
77