1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

#include <string_view>
14
15 namespace torch {
16 namespace executor {
17 namespace native {
18 namespace internal {
19
// Stamps out `template <typename T> T fn_name(const T, const T)` computing
// `lhs infix_op rhs`. Integer promotion may widen the intermediate result
// (e.g. for bool or 8-bit types), so it is converted back to T on return.
#define DEFINE_BINARY_OPERATOR_TEMPLATE(fn_name, infix_op) \
  template <typename T>                                    \
  T fn_name(const T lhs, const T rhs) {                    \
    return static_cast<T>(lhs infix_op rhs);               \
  }

DEFINE_BINARY_OPERATOR_TEMPLATE(bitwise_and, &)
DEFINE_BINARY_OPERATOR_TEMPLATE(bitwise_or, |)
DEFINE_BINARY_OPERATOR_TEMPLATE(bitwise_xor, ^)
29
// Pointer to an element-wise binary bitwise helper over T. Top-level const on
// parameters is not part of a function type, so this is exactly the type
// `T (*)(const T, const T)` of the helpers defined above.
template <typename T>
using bitwise_fn = T (*)(T, T);
32
33 template <typename T, const char* op_name>
get_bitwise_fn()34 constexpr bitwise_fn<T> get_bitwise_fn() {
35 std::string_view op = op_name;
36 if (op == "bitwise_and.Tensor_out" || op == "bitwise_and.Scalar_out") {
37 return bitwise_and;
38 }
39 if (op == "bitwise_or.Tensor_out" || op == "bitwise_or.Scalar_out") {
40 return bitwise_or;
41 }
42 if (op == "bitwise_xor.Tensor_out" || op == "bitwise_xor.Scalar_out") {
43 return bitwise_xor;
44 }
45 return nullptr;
46 };
47
48 template <typename T, const char* op_name>
49 struct BitwiseFnForOp {
50 static constexpr auto value = get_bitwise_fn<T, op_name>();
51 static_assert(value != nullptr, "unknown op_name!");
52 };
53
/**
 * Shared implementation of the `bitwise_*.Tensor_out` kernels. It applies the
 * bitwise helper selected by `op_name` element-wise to `a` and `b`, writing
 * into `out`, which is resized to the broadcast shape of `a` and `b`.
 *
 * @tparam op_name Operator name. It selects the helper via BitwiseFnForOp and
 *                 is also passed as the name argument of the dtype switch and
 *                 of the elementwise utility.
 * @param ctx Kernel runtime context (presumably records the error when a
 *            check below fails).
 * @param a   First input tensor.
 * @param b   Second input tensor.
 * @param out Output tensor.
 * @return `out`, both on success and on every early exit from a failed check.
 */
template <const char* op_name>
Tensor& bitwise_tensor_out(
    RuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    Tensor& out) {
  // Common Dtype
  // The dtype a and b promote to together.
  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());

  // Check Common Dtype
  // The promoted dtype must be castable to out's dtype.
  ET_KERNEL_CHECK(
      ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

  // Resize
  // out takes the broadcast shape of a and b. This runs only after the dtype
  // and dim-order checks have passed.
  ET_KERNEL_CHECK(
      ctx,
      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
      InvalidArgument,
      out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  // Dispatch over the integer dtypes plus Bool; CTYPE_COMPUTE names the
  // selected C++ element type inside the lambda. A compute dtype outside this
  // set (e.g. a floating type) presumably takes the switch's unhandled-dtype
  // path -- confirm against the ET_SWITCH macro definition.
  ET_SWITCH_INT_TYPES_AND(
      Bool, compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
        // Inputs may be any INTB dtype and out any REALHBBF16 dtype; the
        // utility presumably converts to/from CTYPE_COMPUTE per element.
        utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
            BitwiseFnForOp<CTYPE_COMPUTE, op_name>::value,
            ctx,
            a,
            utils::SupportedTensorDtypes::INTB,
            b,
            utils::SupportedTensorDtypes::INTB,
            out,
            utils::SupportedTensorDtypes::REALHBBF16);
      });

  return out;
}
96
/**
 * Shared implementation of the `bitwise_*.Scalar_out` kernels. It applies the
 * bitwise helper selected by `op_name` to each element of `a` paired with the
 * scalar `b`, writing into `out`, which is resized to `a`'s shape.
 *
 * @tparam op_name Operator name. It selects the helper via BitwiseFnForOp and
 *                 is also passed as the name argument of the dtype switch and
 *                 of the elementwise utility.
 * @param ctx Kernel runtime context (presumably records the error when a
 *            check below fails).
 * @param a   Input tensor.
 * @param b   Scalar right-hand operand.
 * @param out Output tensor.
 * @return `out`, both on success and on every early exit from a failed check.
 */
template <const char* op_name>
Tensor& bitwise_scalar_out(
    RuntimeContext& ctx,
    const Tensor& a,
    const Scalar& b,
    Tensor& out) {
  // Common Dtype
  // The dtype a promotes to when combined with the scalar b.
  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);

  // Check Common Dtype
  // The promoted dtype must be castable to out's dtype.
  ET_KERNEL_CHECK(
      ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out);

  // Check Dim Order
  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

  // Resize
  // No broadcasting with a scalar: out simply takes a's sizes.
  ET_KERNEL_CHECK(
      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);

  // Compute Dtype
  ScalarType compute_type = utils::get_compute_type(common_type);

  // Dispatch over the integer dtypes plus Bool; CTYPE_COMPUTE names the
  // selected C++ element type inside the lambda. A compute dtype outside this
  // set presumably takes the switch's unhandled-dtype path -- confirm.
  ET_SWITCH_INT_TYPES_AND(
      Bool, compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
        // Convert the scalar once, up front, rather than per element.
        const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
        utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
            // Captures the converted scalar by value and applies the
            // selected helper as (element, scalar).
            [val_b](const CTYPE_COMPUTE val_a) {
              return BitwiseFnForOp<CTYPE_COMPUTE, op_name>::value(
                  val_a, val_b);
            },
            ctx,
            a,
            utils::SupportedTensorDtypes::INTB,
            out,
            utils::SupportedTensorDtypes::REALHBBF16);
      });

  return out;
}
138
139 } // namespace internal
140 } // namespace native
141 } // namespace executor
142 } // namespace torch
143