xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/math_util.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 namespace torch {
12 namespace executor {
13 namespace native {
14 namespace utils {
15 
/**
 * Python's __floordiv__ operator is more involved than just floor(a / b).
 * It aims to maintain the identity a == (a // b) * b + remainder(a, b),
 * which can otherwise fail because of rounding errors in the remainder.
 * So, for floating-point inputs, it is computed as
 * a // b = (a - remainder(a, b)) / b, with some additional fix-ups applied
 * to the result.
 */
/**
 * Integer floor division: the quotient is rounded toward negative infinity.
 *
 * Built-in integer division truncates toward zero. That already equals the
 * floor when both operands have the same sign or when the division is exact.
 * Only when the signs differ and the remainder is nonzero is the truncated
 * quotient one larger than the floor, so it is decremented in that case.
 *
 * As with the built-in operators, b must be nonzero.
 */
template <
    typename INT_T,
    typename std::enable_if<std::is_integral<INT_T>::value, bool>::type = true>
INT_T floor_divide(INT_T a, INT_T b) {
  const auto truncated = a / b;
  const bool signs_differ = std::signbit(a) != std::signbit(b);
  // Short-circuit: the remainder is only computed when the signs differ.
  if (signs_differ && (a % b) != 0) {
    return truncated - 1;
  }
  return truncated;
}
34 
35 template <
36     typename FLOAT_T,
37     typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
38         type = true>
floor_divide(FLOAT_T a,FLOAT_T b)39 FLOAT_T floor_divide(FLOAT_T a, FLOAT_T b) {
40   if (b == 0) {
41     return std::signbit(a) ? -INFINITY : INFINITY;
42   }
43   const auto mod = std::fmod(a, b);
44   auto div = (a - mod) / b;
45   if ((mod != 0) && std::signbit(b) != std::signbit(mod)) {
46     return div - 1;
47   }
48   return div;
49 }
50 
/**
 * Override min/max so we can emulate PyTorch's behavior with NaN entries:
 * if either input is NaN, the result is NaN.
 */
54 
55 template <
56     typename FLOAT_T,
57     typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
58         type = true>
min_override(FLOAT_T a,FLOAT_T b)59 FLOAT_T min_override(FLOAT_T a, FLOAT_T b) {
60   if (std::isnan(a)) {
61     return a;
62   } else if (std::isnan(b)) {
63     return b;
64   } else {
65     return std::min(a, b);
66   }
67 }
68 
69 template <
70     typename FLOAT_T,
71     typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
72         type = true>
max_override(FLOAT_T a,FLOAT_T b)73 FLOAT_T max_override(FLOAT_T a, FLOAT_T b) {
74   if (std::isnan(a)) {
75     return a;
76   } else if (std::isnan(b)) {
77     return b;
78   } else {
79     return std::max(a, b);
80   }
81 }
82 
/**
 * Minimum for integral types. Integers have no NaN, so no special handling
 * is needed. On ties the first argument is returned, as with std::min.
 */
template <
    typename INT_T,
    typename std::enable_if<std::is_integral<INT_T>::value, bool>::type = true>
INT_T min_override(INT_T a, INT_T b) {
  return (b < a) ? b : a;
}
89 
/**
 * Maximum for integral types. Integers have no NaN, so no special handling
 * is needed. On ties the first argument is returned, as with std::max.
 */
template <
    typename INT_T,
    typename std::enable_if<std::is_integral<INT_T>::value, bool>::type = true>
INT_T max_override(INT_T a, INT_T b) {
  return (a < b) ? b : a;
}
96 
97 template <
98     typename T,
99     typename std::enable_if<std::is_same<T, exec_aten::Half>::value, bool>::
100         type = true>
min_override(T a,T b)101 T min_override(T a, T b) {
102   const auto float_a = static_cast<float>(a);
103   if (std::isnan(float_a)) {
104     return a;
105   }
106   const auto float_b = static_cast<float>(b);
107   if (std::isnan(float_b)) {
108     return b;
109   }
110 
111   if (float_a < float_b) {
112     return a;
113   }
114   return b;
115 }
116 
117 template <
118     typename T,
119     typename std::enable_if<std::is_same<T, exec_aten::Half>::value, bool>::
120         type = true>
max_override(T a,T b)121 T max_override(T a, T b) {
122   const auto float_a = static_cast<float>(a);
123   if (std::isnan(float_a)) {
124     return a;
125   }
126   const auto float_b = static_cast<float>(b);
127   if (std::isnan(float_b)) {
128     return b;
129   }
130 
131   if (float_a > float_b) {
132     return a;
133   }
134   return b;
135 }
136 
/**
 * There is a slight difference between how std::fmod works and how ATen
 * determines remainders:
 * the value returned by std::fmod has the same sign as x and is smaller than
 * y in magnitude (https://en.cppreference.com/w/cpp/numeric/math/fmod),
 * whereas a nonzero ATen remainder always has the same sign as y.
 * To correct this, y is added to the fmod result when exactly one of x and y
 * is negative and the result is not 0.
 */
145  */
146 
/**
 * Floating-point remainder with ATen semantics: a nonzero result takes the
 * sign of the divisor b.
 *
 * std::fmod(a, b) returns a value with the sign of a. When exactly one of a
 * and b is negative and that value is nonzero, b is added to it so the result
 * takes b's sign. A zero fmod result is returned unchanged. If b is 0,
 * std::fmod yields NaN, which propagates.
 *
 * @param a Dividend.
 * @param b Divisor.
 * @return The remainder of a / b, sign-adjusted to match b, computed in CTYPE
 *         precision.
 */
template <
    typename CTYPE,
    typename std::enable_if<std::is_floating_point<CTYPE>::value, int>::type =
        0>
CTYPE remainder_override(CTYPE a, CTYPE b) {
  // Keep the intermediate in CTYPE; a `float` local would silently truncate
  // double / long double results.
  CTYPE rem = std::fmod(a, b);
  if (((a < 0) != (b < 0)) && rem != 0) {
    rem += b;
  }
  return rem;
}
158 
/**
 * Integer remainder with ATen/Python semantics: a nonzero result takes the
 * sign of the divisor b (e.g. remainder_override(-7, 3) == 2).
 *
 * The built-in % truncates toward zero, so its result carries the sign of a.
 * When that nonzero result's sign differs from b's sign, b is added to it.
 * This is the same correction the floating-point overload applies to
 * std::fmod. As with %, b must be nonzero.
 *
 * @param a Dividend.
 * @param b Divisor (nonzero).
 * @return a % b, shifted by b when needed so a nonzero result matches b's sign.
 */
template <
    typename CTYPE,
    typename std::enable_if<std::is_integral<CTYPE>::value, int>::type = 0>
CTYPE remainder_override(CTYPE a, CTYPE b) {
  const CTYPE rem = static_cast<CTYPE>(a % b);
  // Same sign test as the integral floor_divide. For unsigned types both
  // signbits are false, so no adjustment is ever made.
  if (rem != 0 && std::signbit(rem) != std::signbit(b)) {
    // |rem| < |b| and the signs are opposite, so rem + b stays in range.
    return static_cast<CTYPE>(rem + b);
  }
  return rem;
}
165 
166 } // namespace utils
167 } // namespace native
168 } // namespace executor
169 } // namespace torch
170