1 #include <cmath>
2 #include <limits>
3 #include <vector>
4
5 #include <c10/util/Half.h>
6 #include <c10/util/floating_point_utils.h>
7 #include <c10/util/irange.h>
8 #include <gtest/gtest.h>
9
10 namespace {
11
halfbits2float(unsigned short h)12 float halfbits2float(unsigned short h) {
13 unsigned sign = ((h >> 15) & 1);
14 unsigned exponent = ((h >> 10) & 0x1f);
15 unsigned mantissa = ((h & 0x3ff) << 13);
16
17 if (exponent == 0x1f) { /* NaN or Inf */
18 mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
19 exponent = 0xff;
20 } else if (!exponent) { /* Denorm or Zero */
21 if (mantissa) {
22 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
23 unsigned int msb;
24 exponent = 0x71;
25 do {
26 msb = (mantissa & 0x400000);
27 mantissa <<= 1; /* normalize */
28 --exponent;
29 } while (!msb);
30 mantissa &= 0x7fffff; /* 1.mantissa is implicit */
31 }
32 } else {
33 exponent += 0x70;
34 }
35
36 unsigned result_bit = (sign << 31) | (exponent << 23) | mantissa;
37
38 return c10::detail::fp32_from_bits(result_bit);
39 }
40
float2halfbits(float src)41 unsigned short float2halfbits(float src) {
42 unsigned x = c10::detail::fp32_to_bits(src);
43
44 // NOLINTNEXTLINE(cppcoreguidelines-init-variables,cppcoreguidelines-avoid-magic-numbers)
45 unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
46 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
47 unsigned sign, exponent, mantissa;
48
49 // Get rid of +NaN/-NaN case first.
50 if (u > 0x7f800000) {
51 return 0x7fffU;
52 }
53
54 sign = ((x >> 16) & 0x8000);
55
56 // Get rid of +Inf/-Inf, +0/-0.
57 if (u > 0x477fefff) {
58 return sign | 0x7c00U;
59 }
60 if (u < 0x33000001) {
61 return (sign | 0x0000);
62 }
63
64 exponent = ((u >> 23) & 0xff);
65 mantissa = (u & 0x7fffff);
66
67 if (exponent > 0x70) {
68 shift = 13;
69 exponent -= 0x70;
70 } else {
71 shift = 0x7e - exponent;
72 exponent = 0;
73 mantissa |= 0x800000;
74 }
75 lsb = (1 << shift);
76 lsb_s1 = (lsb >> 1);
77 lsb_m1 = (lsb - 1);
78
79 // Round to nearest even.
80 remainder = (mantissa & lsb_m1);
81 mantissa >>= shift;
82 if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
83 ++mantissa;
84 if (!(mantissa & 0x3ff)) {
85 ++exponent;
86 mantissa = 0;
87 }
88 }
89
90 return (sign | (exponent << 10) | mantissa);
91 }
// Cross-checks the portable reference conversions (halfbits2float /
// float2halfbits) against c10::detail::fp16_ieee_to_fp32_value /
// fp16_ieee_from_fp32_value on a handful of representative half bit patterns:
// signed zeros, smallest subnormals, one, infinity and +/- max finite.
TEST(HalfConversionTest, TestPorableConversion) {
  std::vector<uint16_t> inputs = {
      0,
      0x8000, // -0
      0x0001, // smallest positive subnormal, 2^-24
      (1 << 15 | 1), // smallest-magnitude negative subnormal, -2^-24
      0x3c00, // 1.0
      0x7c00, // +inf
      0xfbff, // 1111 1011 1111 1111 (-65504, lowest finite)
      0x7bff // 0111 1011 1111 1111 (65504, max finite)
  };
  for (auto x : inputs) {
    auto target = c10::detail::fp16_ieee_to_fp32_value(x);
    EXPECT_EQ(halfbits2float(x), target)
        << "Test failed for uint16 to float " << x << "\n";
    // The separating space keeps the value from running into the message.
    EXPECT_EQ(
        float2halfbits(target), c10::detail::fp16_ieee_from_fp32_value(target))
        << "Test failed for float to uint16 " << target << "\n";
  }
}
108
// Exhaustively compares c10::Half -> float against the reference
// halfbits2float for every one of the 2^16 half bit patterns. A pair of NaNs
// counts as a match because NaN never compares equal to itself.
TEST(HalfConversion, TestNativeConversionToFloat) {
  constexpr int kNumHalfValues = std::numeric_limits<uint16_t>::max() + 1;
  for (const auto bits : c10::irange(kNumHalfValues)) {
    const float expected = halfbits2float(bits);
    const float actual =
        static_cast<float>(c10::Half(bits, c10::Half::from_bits()));
    const bool both_nan = std::isnan(expected) && std::isnan(actual);
    if (!both_nan) {
      EXPECT_EQ(expected, actual) << "Conversion error using " << bits;
    }
  }
}
121
// Checks float -> c10::Half against the reference float2halfbits. It first
// tests every float that is exactly representable as a half, then tests
// floats that force rounding, overflow or underflow.
TEST(HalfConversion, TestNativeConversionToHalf) {
  auto check_conversion = [](float f) {
    auto h = c10::Half(f);
    auto h_bits = float2halfbits(f);
    // NaNs are not equal to each other, just check that half is NaN
    if (std::isnan(f)) {
      EXPECT_TRUE(std::isnan(static_cast<float>(h)));
    } else {
      EXPECT_EQ(h.x, h_bits) << "Conversion error using " << f;
    }
  };

  // Every half value, widened to float. These are exact, so no rounding
  // happens in this loop.
  for (auto x : c10::irange(std::numeric_limits<uint16_t>::max() + 1)) {
    check_conversion(halfbits2float(x));
  }

  // Float extremes. max() and lowest() overflow to +/-inf. min() is the
  // smallest *normal* float (2^-126), not the most negative one; like
  // denorm_min() it underflows to zero. epsilon() (2^-23) is exactly
  // representable as a half subnormal.
  check_conversion(std::numeric_limits<float>::max());
  check_conversion(std::numeric_limits<float>::min());
  check_conversion(std::numeric_limits<float>::epsilon());
  check_conversion(std::numeric_limits<float>::lowest());
  check_conversion(std::numeric_limits<float>::denorm_min());
  check_conversion(-std::numeric_limits<float>::min());

  // Inexact inputs that exercise round-to-nearest-even, which the loop above
  // never reaches.
  // 1 + 2^-11 ties between 1.0 (even) and 1 + 2^-10 (odd), so it rounds down.
  check_conversion(1.0f + std::ldexp(1.0f, -11));
  // 1 + 3 * 2^-11 ties between 1 + 2^-10 (odd) and 1 + 2^-9 (even), so it
  // rounds up.
  check_conversion(1.0f + 3.0f * std::ldexp(1.0f, -11));
  // 65519 rounds down to the largest finite half (65504). 65520 is a tie that
  // rounds up to infinity.
  check_conversion(65519.0f);
  check_conversion(65520.0f);
  // 2^-25 is half the smallest subnormal half, a tie that rounds to zero. The
  // next float above it rounds up to the smallest subnormal.
  check_conversion(std::ldexp(1.0f, -25));
  check_conversion(std::nextafter(std::ldexp(1.0f, -25), 1.0f));
}
143
144 } // namespace
145