1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cmath>
10
11 #include <executorch/kernels/portable/cpu/scalar_utils.h>
12 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14
15 namespace torch {
16 namespace executor {
17 namespace native {
18
pow_Tensor_Tensor_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)19 Tensor& pow_Tensor_Tensor_out(
20 KernelRuntimeContext& ctx,
21 const Tensor& a,
22 const Tensor& b,
23 Tensor& out) {
24 // Common Dtype
25 ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
26
27 // Check Common Dtype
28 ET_KERNEL_CHECK(
29 ctx,
30 (canCast(common_type, out.scalar_type()) &&
31 common_type != ScalarType::Bool),
32 InvalidArgument,
33 out);
34
35 // Check Dim Order
36 ET_KERNEL_CHECK(
37 ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
38
39 // Resize
40 ET_KERNEL_CHECK(
41 ctx,
42 resize_to_broadcast_target_size(a, b, out) == Error::Ok,
43 InvalidArgument,
44 out);
45
46 // Compute Dtype
47 ScalarType compute_type = utils::get_compute_type(common_type);
48 if (compute_type != ScalarType::Float) {
49 compute_type = ScalarType::Double;
50 }
51
52 // @lint-ignore CLANGTIDY facebook-hte-CArray
53 static constexpr const char op_name[] = "pow.Tensor_Tensor_out";
54
55 ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
56 utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
57 [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
58 return std::pow(val_a, val_b);
59 },
60 ctx,
61 a,
62 utils::SupportedTensorDtypes::REALHBBF16,
63 b,
64 utils::SupportedTensorDtypes::REALHBBF16,
65 out,
66 utils::SupportedTensorDtypes::REALHBF16);
67 });
68
69 return out;
70 }
71
pow_Tensor_Scalar_out(KernelRuntimeContext & ctx,const Tensor & a,const Scalar & b,Tensor & out)72 Tensor& pow_Tensor_Scalar_out(
73 KernelRuntimeContext& ctx,
74 const Tensor& a,
75 const Scalar& b,
76 Tensor& out) {
77 // Common Dtype
78 ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
79
80 // Check Common Dtype
81 ET_KERNEL_CHECK(
82 ctx,
83 (canCast(common_type, out.scalar_type()) &&
84 common_type != ScalarType::Bool),
85 InvalidArgument,
86 out);
87
88 // Check Dim Order
89 ET_KERNEL_CHECK(
90 ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
91
92 // Resize
93 ET_KERNEL_CHECK(
94 ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
95
96 // Compute Dtype
97 ScalarType compute_type = utils::get_compute_type(common_type);
98 if (compute_type != ScalarType::Float) {
99 compute_type = ScalarType::Double;
100 }
101
102 // @lint-ignore CLANGTIDY facebook-hte-CArray
103 static constexpr const char op_name[] = "pow.Tensor_Scalar_out";
104
105 ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
106 const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
107 utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
108 [val_b](const CTYPE_COMPUTE val_a) { return std::pow(val_a, val_b); },
109 ctx,
110 a,
111 utils::SupportedTensorDtypes::REALHBBF16,
112 out,
113 utils::SupportedTensorDtypes::REALHBF16);
114 });
115
116 return out;
117 }
118
pow_Scalar_out(KernelRuntimeContext & ctx,const Scalar & a,const Tensor & b,Tensor & out)119 Tensor& pow_Scalar_out(
120 KernelRuntimeContext& ctx,
121 const Scalar& a,
122 const Tensor& b,
123 Tensor& out) {
124 // Common Dtype
125 ScalarType common_type = utils::promote_type_with_scalar(b.scalar_type(), a);
126
127 // Check Common Dtype
128 ET_KERNEL_CHECK(
129 ctx,
130 (canCast(common_type, out.scalar_type()) &&
131 common_type != ScalarType::Bool),
132 InvalidArgument,
133 out);
134
135 // Check Dim Order
136 ET_KERNEL_CHECK(
137 ctx, tensors_have_same_dim_order(b, out), InvalidArgument, out);
138
139 // Resize
140 ET_KERNEL_CHECK(
141 ctx, resize_tensor(out, b.sizes()) == Error::Ok, InvalidArgument, out);
142
143 // Compute Dtype
144 ScalarType compute_type = utils::get_compute_type(common_type);
145 if (compute_type != ScalarType::Float) {
146 compute_type = ScalarType::Double;
147 }
148
149 // @lint-ignore CLANGTIDY facebook-hte-CArray
150 static constexpr const char op_name[] = "pow.Scalar_out";
151
152 ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
153 const CTYPE_COMPUTE val_a = utils::scalar_to<CTYPE_COMPUTE>(a);
154 utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
155 [val_a](const CTYPE_COMPUTE val_b) { return std::pow(val_a, val_b); },
156 ctx,
157 b,
158 utils::SupportedTensorDtypes::REALHBBF16,
159 out,
160 utils::SupportedTensorDtypes::REALHBF16);
161 });
162
163 return out;
164 }
165
166 } // namespace native
167 } // namespace executor
168 } // namespace torch
169