1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <cmath>
10
11 #include <executorch/kernels/portable/cpu/scalar_utils.h>
12 #include <executorch/kernels/portable/cpu/util/elementwise_util.h>
13 #include <executorch/runtime/kernel/kernel_includes.h>
14
15 namespace torch {
16 namespace executor {
17 namespace native {
18
fmod_Tensor_out(KernelRuntimeContext & ctx,const Tensor & a,const Tensor & b,Tensor & out)19 Tensor& fmod_Tensor_out(
20 KernelRuntimeContext& ctx,
21 const Tensor& a,
22 const Tensor& b,
23 Tensor& out) {
24 // Common Dtype
25 ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
26
27 // Check Common Dtype
28 ET_KERNEL_CHECK(
29 ctx,
30 (canCast(common_type, out.scalar_type()) &&
31 common_type != ScalarType::Bool),
32 InvalidArgument,
33 out);
34
35 // Check Dim Order
36 ET_KERNEL_CHECK(
37 ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
38
39 // Resize
40 ET_KERNEL_CHECK(
41 ctx,
42 resize_to_broadcast_target_size(a, b, out) == Error::Ok,
43 InvalidArgument,
44 out);
45
46 // Compute Dtype
47 ScalarType compute_type = utils::get_compute_type(common_type);
48 if (compute_type != ScalarType::Float) {
49 compute_type = ScalarType::Double;
50 }
51
52 // @lint-ignore CLANGTIDY facebook-hte-CArray
53 static constexpr const char op_name[] = "fmod.Tensor_out";
54
55 bool div_by_zero_error = false;
56
57 ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
58 utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
59 [&div_by_zero_error](
60 const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
61 CTYPE_COMPUTE value = 0;
62 if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
63 if (val_b == 0) {
64 div_by_zero_error = true;
65 return value;
66 }
67 }
68 value = std::fmod(val_a, val_b);
69 return value;
70 },
71 ctx,
72 a,
73 utils::SupportedTensorDtypes::REALHBBF16,
74 b,
75 utils::SupportedTensorDtypes::REALHBBF16,
76 out,
77 utils::SupportedTensorDtypes::REALHBF16);
78 });
79
80 ET_KERNEL_CHECK_MSG(
81 ctx,
82 !div_by_zero_error,
83 InvalidArgument,
84 out,
85 "Fmod operation encountered integer division by zero");
86
87 return out;
88 }
89
fmod_Scalar_out(KernelRuntimeContext & ctx,const Tensor & a,const Scalar & b,Tensor & out)90 Tensor& fmod_Scalar_out(
91 KernelRuntimeContext& ctx,
92 const Tensor& a,
93 const Scalar& b,
94 Tensor& out) {
95 // Common Dtype
96 ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
97
98 // Check Common Dtype
99 ET_KERNEL_CHECK(
100 ctx,
101 (canCast(common_type, out.scalar_type()) &&
102 common_type != ScalarType::Bool),
103 InvalidArgument,
104 out);
105
106 // Check for intergral division by zero
107 ET_KERNEL_CHECK_MSG(
108 ctx,
109 !(executorch::runtime::isIntegralType(common_type, true) &&
110 utils::scalar_to<double>(b) == 0),
111 InvalidArgument,
112 out,
113 "Fmod operation encountered integer division by zero");
114
115 // Check Dim Order
116 ET_KERNEL_CHECK(
117 ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
118
119 // Resize
120 ET_KERNEL_CHECK(
121 ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
122
123 // Compute Dtype
124 ScalarType compute_type = utils::get_compute_type(common_type);
125 if (compute_type != ScalarType::Float) {
126 compute_type = ScalarType::Double;
127 }
128
129 // @lint-ignore CLANGTIDY facebook-hte-CArray
130 static constexpr const char op_name[] = "fmod.Scalar_out";
131
132 ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
133 const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
134 utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
135 [val_b](const CTYPE_COMPUTE val_a) {
136 CTYPE_COMPUTE value = std::fmod(val_a, val_b);
137 return value;
138 },
139 ctx,
140 a,
141 utils::SupportedTensorDtypes::REALHBBF16,
142 out,
143 utils::SupportedTensorDtypes::REALHBF16);
144 });
145
146 return out;
147 }
148
149 } // namespace native
150 } // namespace executor
151 } // namespace torch
152