xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/op_pow.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <cmath>
10 
11 #include <executorch/kernels/portable/cpu/scalar_utils.h>
12 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14 
15 namespace torch {
16 namespace executor {
17 namespace native {
18 
pow_Tensor_Tensor_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)19 Tensor& pow_Tensor_Tensor_out(
20     KernelRuntimeContext& ctx,
21     const Tensor& a,
22     const Tensor& b,
23     Tensor& out) {
24   // Common Dtype
25   ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
26 
27   // Check Common Dtype
28   ET_KERNEL_CHECK(
29       ctx,
30       (canCast(common_type, out.scalar_type()) &&
31        common_type != ScalarType::Bool),
32       InvalidArgument,
33       out);
34 
35   // Check Dim Order
36   ET_KERNEL_CHECK(
37       ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
38 
39   // Resize
40   ET_KERNEL_CHECK(
41       ctx,
42       resize_to_broadcast_target_size(a, b, out) == Error::Ok,
43       InvalidArgument,
44       out);
45 
46   // Compute Dtype
47   ScalarType compute_type = utils::get_compute_type(common_type);
48   if (compute_type != ScalarType::Float) {
49     compute_type = ScalarType::Double;
50   }
51 
52   // @lint-ignore CLANGTIDY facebook-hte-CArray
53   static constexpr const char op_name[] = "pow.Tensor_Tensor_out";
54 
55   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
56     utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
57         [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
58           return std::pow(val_a, val_b);
59         },
60         ctx,
61         a,
62         utils::SupportedTensorDtypes::REALHBBF16,
63         b,
64         utils::SupportedTensorDtypes::REALHBBF16,
65         out,
66         utils::SupportedTensorDtypes::REALHBF16);
67   });
68 
69   return out;
70 }
71 
pow_Tensor_Scalar_out(KernelRuntimeContext & ctx,const Tensor & a,const Scalar & b,Tensor & out)72 Tensor& pow_Tensor_Scalar_out(
73     KernelRuntimeContext& ctx,
74     const Tensor& a,
75     const Scalar& b,
76     Tensor& out) {
77   // Common Dtype
78   ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
79 
80   // Check Common Dtype
81   ET_KERNEL_CHECK(
82       ctx,
83       (canCast(common_type, out.scalar_type()) &&
84        common_type != ScalarType::Bool),
85       InvalidArgument,
86       out);
87 
88   // Check Dim Order
89   ET_KERNEL_CHECK(
90       ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
91 
92   // Resize
93   ET_KERNEL_CHECK(
94       ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
95 
96   // Compute Dtype
97   ScalarType compute_type = utils::get_compute_type(common_type);
98   if (compute_type != ScalarType::Float) {
99     compute_type = ScalarType::Double;
100   }
101 
102   // @lint-ignore CLANGTIDY facebook-hte-CArray
103   static constexpr const char op_name[] = "pow.Tensor_Scalar_out";
104 
105   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
106     const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
107     utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
108         [val_b](const CTYPE_COMPUTE val_a) { return std::pow(val_a, val_b); },
109         ctx,
110         a,
111         utils::SupportedTensorDtypes::REALHBBF16,
112         out,
113         utils::SupportedTensorDtypes::REALHBF16);
114   });
115 
116   return out;
117 }
118 
pow_Scalar_out(KernelRuntimeContext & ctx,const Scalar & a,const Tensor & b,Tensor & out)119 Tensor& pow_Scalar_out(
120     KernelRuntimeContext& ctx,
121     const Scalar& a,
122     const Tensor& b,
123     Tensor& out) {
124   // Common Dtype
125   ScalarType common_type = utils::promote_type_with_scalar(b.scalar_type(), a);
126 
127   // Check Common Dtype
128   ET_KERNEL_CHECK(
129       ctx,
130       (canCast(common_type, out.scalar_type()) &&
131        common_type != ScalarType::Bool),
132       InvalidArgument,
133       out);
134 
135   // Check Dim Order
136   ET_KERNEL_CHECK(
137       ctx, tensors_have_same_dim_order(b, out), InvalidArgument, out);
138 
139   // Resize
140   ET_KERNEL_CHECK(
141       ctx, resize_tensor(out, b.sizes()) == Error::Ok, InvalidArgument, out);
142 
143   // Compute Dtype
144   ScalarType compute_type = utils::get_compute_type(common_type);
145   if (compute_type != ScalarType::Float) {
146     compute_type = ScalarType::Double;
147   }
148 
149   // @lint-ignore CLANGTIDY facebook-hte-CArray
150   static constexpr const char op_name[] = "pow.Scalar_out";
151 
152   ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
153     const CTYPE_COMPUTE val_a = utils::scalar_to<CTYPE_COMPUTE>(a);
154     utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
155         [val_a](const CTYPE_COMPUTE val_b) { return std::pow(val_a, val_b); },
156         ctx,
157         b,
158         utils::SupportedTensorDtypes::REALHBBF16,
159         out,
160         utils::SupportedTensorDtypes::REALHBF16);
161   });
162 
163   return out;
164 }
165 
166 } // namespace native
167 } // namespace executor
168 } // namespace torch
169