1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #pragma once
10
11 namespace torch {
12 namespace executor {
13 namespace native {
14 namespace utils {
15
16 /**
17 * Python's __floordiv__ operator is more complicated than just floor(a / b).
18 * It aims to maintain the property: a == (a // b) * b + remainder(a, b)
19 * which can otherwise fail due to rounding errors in the remainder.
20 * So, instead it is calculated as: a // b = (a - remainder(a, b)) / b
21 * With some additional fix-ups added to the result.
22 */
/**
 * Floor division for integral types.
 *
 * The built-in `/` truncates toward zero. That already equals the floor when
 * both operands share a sign or when the division is exact; otherwise the
 * truncated quotient sits one above the floor and is decremented.
 *
 * As with the built-in operator, b == 0 is undefined behavior.
 */
template <
    typename INT_T,
    typename std::enable_if<std::is_integral<INT_T>::value, bool>::type = true>
INT_T floor_divide(INT_T a, INT_T b) {
  const auto truncated = a / b;
  const bool mixed_signs = std::signbit(a) != std::signbit(b);
  // Only an inexact quotient with mixed signs needs the downward correction.
  return (mixed_signs && (a % b != 0)) ? truncated - 1 : truncated;
}
34
/**
 * Floor division for floating-point types, following the same algorithm as
 * ATen's floor division for floating-point inputs.
 *
 * The quotient is derived from the fmod remainder, i.e. (a - fmod(a, b)) / b,
 * and then fixed up:
 *  - when the remainder's sign disagrees with b's, the truncated quotient is
 *    one too large and is decremented;
 *  - a zero result keeps the sign of the true quotient a / b;
 *  - any rounding error in (a - fmod(a, b)) / b is removed by snapping to an
 *    integer (floor, rounding up if the fractional part exceeds 0.5).
 *
 * Division by a (signed) zero returns the IEEE-754 result of a / b: +/-inf
 * with the combined sign of a and b, or NaN for 0/0 and NaN/0.
 */
template <
    typename FLOAT_T,
    typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
        type = true>
FLOAT_T floor_divide(FLOAT_T a, FLOAT_T b) {
  if (b == 0) {
    // Let IEEE-754 pick the sign of the infinity (it depends on b's sign too)
    // and produce NaN for 0/0.
    return a / b;
  }
  const FLOAT_T mod = std::fmod(a, b);
  FLOAT_T div = (a - mod) / b;
  // fmod's result carries the sign of `a`; if it disagrees with `b`, the
  // quotient was truncated upward and must be moved down by one.
  if ((mod != 0) && std::signbit(b) != std::signbit(mod)) {
    div -= FLOAT_T(1);
  }
  if (div == 0) {
    // Preserve the sign of a zero quotient, e.g. -0.0 // 3.0 == -0.0.
    return std::copysign(FLOAT_T(0), a / b);
  }
  // (a - mod) / b may be a hair away from an integer due to rounding; snap it
  // to the intended integer value.
  FLOAT_T floordiv = std::floor(div);
  if (div - floordiv > FLOAT_T(0.5)) {
    floordiv += FLOAT_T(1);
  }
  return floordiv;
}
50
/**
 * Override min/max so that NaN entries propagate, emulating PyTorch's behavior
 * (std::min/std::max alone may drop a NaN operand depending on argument order).
 */
54
/**
 * NaN-propagating minimum for floating-point types.
 *
 * If either operand is NaN, that operand is returned (`a` is checked first).
 * Otherwise the result is std::min(a, b), which returns `a` on ties.
 */
template <
    typename FLOAT_T,
    typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
        type = true>
FLOAT_T min_override(FLOAT_T a, FLOAT_T b) {
  const bool a_is_nan = std::isnan(a);
  if (a_is_nan || std::isnan(b)) {
    return a_is_nan ? a : b;
  }
  return std::min(a, b);
}
68
/**
 * NaN-propagating maximum for floating-point types.
 *
 * If either operand is NaN, that operand is returned (`a` is checked first).
 * Otherwise the result is std::max(a, b), which returns `a` on ties.
 */
template <
    typename FLOAT_T,
    typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
        type = true>
FLOAT_T max_override(FLOAT_T a, FLOAT_T b) {
  const bool a_is_nan = std::isnan(a);
  if (a_is_nan || std::isnan(b)) {
    return a_is_nan ? a : b;
  }
  return std::max(a, b);
}
82
/**
 * Minimum for integral types. Integers have no NaN, so plain ordering is
 * enough; ties return `a`, exactly like std::min.
 */
template <
    typename INT_T,
    typename std::enable_if<std::is_integral<INT_T>::value, bool>::type = true>
INT_T min_override(INT_T a, INT_T b) {
  return (b < a) ? b : a;
}
89
/**
 * Maximum for integral types. Integers have no NaN, so plain ordering is
 * enough; ties return `a`, exactly like std::max.
 */
template <
    typename INT_T,
    typename std::enable_if<std::is_integral<INT_T>::value, bool>::type = true>
INT_T max_override(INT_T a, INT_T b) {
  return (a < b) ? b : a;
}
96
97 template <
98 typename T,
99 typename std::enable_if<std::is_same<T, exec_aten::Half>::value, bool>::
100 type = true>
min_override(T a,T b)101 T min_override(T a, T b) {
102 const auto float_a = static_cast<float>(a);
103 if (std::isnan(float_a)) {
104 return a;
105 }
106 const auto float_b = static_cast<float>(b);
107 if (std::isnan(float_b)) {
108 return b;
109 }
110
111 if (float_a < float_b) {
112 return a;
113 }
114 return b;
115 }
116
117 template <
118 typename T,
119 typename std::enable_if<std::is_same<T, exec_aten::Half>::value, bool>::
120 type = true>
max_override(T a,T b)121 T max_override(T a, T b) {
122 const auto float_a = static_cast<float>(a);
123 if (std::isnan(float_a)) {
124 return a;
125 }
126 const auto float_b = static_cast<float>(b);
127 if (std::isnan(float_b)) {
128 return b;
129 }
130
131 if (float_a > float_b) {
132 return a;
133 }
134 return b;
135 }
136
/**
 * There is a slight difference in how std::fmod works compared to how ATen
 * determines remainders:
 * The value returned by std::fmod(x, y) has the same sign as x and is less
 * than y in magnitude (https://en.cppreference.com/w/cpp/numeric/math/fmod).
 * ATen's remainder, on the other hand, always matches the sign of y.
 * To correct this, we add y to the remainder when exactly one of x and y is
 * negative and the remainder is not 0.
 */
146
/**
 * Floating-point remainder with ATen semantics: a non-zero result takes the
 * sign of the divisor `b`.
 *
 * std::fmod(a, b) yields a value with the sign of `a`. When exactly one of
 * `a` and `b` is negative and that value is non-zero, adding `b` moves it to
 * the divisor's sign.
 */
template <
    typename CTYPE,
    typename std::enable_if<std::is_floating_point<CTYPE>::value, int>::type =
        0>
CTYPE remainder_override(CTYPE a, CTYPE b) {
  // Keep the intermediate in CTYPE: a `float` here would silently drop
  // precision whenever CTYPE is double (or long double).
  CTYPE rem = std::fmod(a, b);
  if (((a < 0) ^ (b < 0)) && rem != 0) {
    rem += b;
  }
  return rem;
}
158
/**
 * Integral remainder with ATen semantics: a non-zero result takes the sign of
 * the divisor `b`, so remainder_override(-1, 3) == 2.
 *
 * The built-in `%` truncates toward zero and gives the remainder the sign of
 * `a`; when that disagrees with the sign of `b`, adding `b` yields the
 * floor-based remainder. As with `%`, b == 0 (and the minimum signed value
 * divided by -1) is undefined behavior.
 */
template <
    typename CTYPE,
    typename std::enable_if<std::is_integral<CTYPE>::value, int>::type = 0>
CTYPE remainder_override(CTYPE a, CTYPE b) {
  CTYPE rem = a % b;
  // std::signbit (rather than `< 0`) keeps unsigned CTYPE warning-free; for
  // unsigned types it is always false and the correction never applies.
  if (rem != 0 && std::signbit(rem) != std::signbit(b)) {
    rem += b;
  }
  return rem;
}
165
166 } // namespace utils
167 } // namespace native
168 } // namespace executor
169 } // namespace torch
170