xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/pattern/bitwise_op.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
#include <string_view>

#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
14 
15 namespace torch {
16 namespace executor {
17 namespace native {
18 namespace internal {
19 
// Stamps out a function template
//   template <typename T> T fn_name(const T lhs, const T rhs);
// whose body returns `lhs oper rhs`. `oper` is pasted in verbatim, so it has
// to be a binary infix operator token.
#define DEFINE_BINARY_OPERATOR_TEMPLATE(fn_name, oper) \
  template <typename T>                                \
  T fn_name(const T lhs, const T rhs) {                \
    return lhs oper rhs;                               \
  }

// bitwise_and<T>(x, y) -> x & y
DEFINE_BINARY_OPERATOR_TEMPLATE(bitwise_and, &)
// bitwise_or<T>(x, y)  -> x | y
DEFINE_BINARY_OPERATOR_TEMPLATE(bitwise_or, |)
// bitwise_xor<T>(x, y) -> x ^ y
DEFINE_BINARY_OPERATOR_TEMPLATE(bitwise_xor, ^)
29 
// Pointer to a function that takes two `T` operands by value and returns a
// `T`. This is the pointer type that get_bitwise_fn() returns.
template <typename T>
using bitwise_fn = T (*)(T, T);
32 
33 template <typename T, const char* op_name>
get_bitwise_fn()34 constexpr bitwise_fn<T> get_bitwise_fn() {
35   std::string_view op = op_name;
36   if (op == "bitwise_and.Tensor_out" || op == "bitwise_and.Scalar_out") {
37     return bitwise_and;
38   }
39   if (op == "bitwise_or.Tensor_out" || op == "bitwise_or.Scalar_out") {
40     return bitwise_or;
41   }
42   if (op == "bitwise_xor.Tensor_out" || op == "bitwise_xor.Scalar_out") {
43     return bitwise_xor;
44   }
45   return nullptr;
46 };
47 
48 template <typename T, const char* op_name>
49 struct BitwiseFnForOp {
50   static constexpr auto value = get_bitwise_fn<T, op_name>();
51   static_assert(value != nullptr, "unknown op_name!");
52 };
53 
54 template <const char* op_name>
bitwise_tensor_out(RuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)55 Tensor& bitwise_tensor_out(
56     RuntimeContext& ctx,
57     const Tensor& a,
58     const Tensor& b,
59     Tensor& out) {
60   // Common Dtype
61   ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
62 
63   // Check Common Dtype
64   ET_KERNEL_CHECK(
65       ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);
66 
67   // Check Dim Order
68   ET_KERNEL_CHECK(
69       ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
70 
71   // Resize
72   ET_KERNEL_CHECK(
73       ctx,
74       resize_to_broadcast_target_size(a, b, out) == Error::Ok,
75       InvalidArgument,
76       out);
77 
78   // Compute Dtype
79   ScalarType compute_type = utils::get_compute_type(common_type);
80 
81   ET_SWITCH_INT_TYPES_AND(
82       Bool, compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
83         utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
84             BitwiseFnForOp<CTYPE_COMPUTE, op_name>::value,
85             ctx,
86             a,
87             utils::SupportedTensorDtypes::INTB,
88             b,
89             utils::SupportedTensorDtypes::INTB,
90             out,
91             utils::SupportedTensorDtypes::REALHBBF16);
92       });
93 
94   return out;
95 }
96 
97 template <const char* op_name>
bitwise_scalar_out(RuntimeContext & ctx,const Tensor & a,const Scalar & b,Tensor & out)98 Tensor& bitwise_scalar_out(
99     RuntimeContext& ctx,
100     const Tensor& a,
101     const Scalar& b,
102     Tensor& out) {
103   // Common Dtype
104   ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
105 
106   // Check Common Dtype
107   ET_KERNEL_CHECK(
108       ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);
109 
110   // Check Dim Order
111   ET_KERNEL_CHECK(
112       ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
113 
114   // Resize
115   ET_KERNEL_CHECK(
116       ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
117 
118   // Compute Dtype
119   ScalarType compute_type = utils::get_compute_type(common_type);
120 
121   ET_SWITCH_INT_TYPES_AND(
122       Bool, compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
123         const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
124         utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
125             [val_b](const CTYPE_COMPUTE val_a) {
126               return BitwiseFnForOp<CTYPE_COMPUTE, op_name>::value(
127                   val_a, val_b);
128             },
129             ctx,
130             a,
131             utils::SupportedTensorDtypes::INTB,
132             out,
133             utils::SupportedTensorDtypes::REALHBBF16);
134       });
135 
136   return out;
137 }
138 
139 } // namespace internal
140 } // namespace native
141 } // namespace executor
142 } // namespace torch
143