xref: /aosp_15_r20/external/pytorch/c10/test/util/Half_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <cmath>
2 #include <limits>
3 #include <vector>
4 
5 #include <c10/util/Half.h>
6 #include <c10/util/floating_point_utils.h>
7 #include <c10/util/irange.h>
8 #include <gtest/gtest.h>
9 
10 namespace {
11 
halfbits2float(unsigned short h)12 float halfbits2float(unsigned short h) {
13   unsigned sign = ((h >> 15) & 1);
14   unsigned exponent = ((h >> 10) & 0x1f);
15   unsigned mantissa = ((h & 0x3ff) << 13);
16 
17   if (exponent == 0x1f) { /* NaN or Inf */
18     mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
19     exponent = 0xff;
20   } else if (!exponent) { /* Denorm or Zero */
21     if (mantissa) {
22       // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
23       unsigned int msb;
24       exponent = 0x71;
25       do {
26         msb = (mantissa & 0x400000);
27         mantissa <<= 1; /* normalize */
28         --exponent;
29       } while (!msb);
30       mantissa &= 0x7fffff; /* 1.mantissa is implicit */
31     }
32   } else {
33     exponent += 0x70;
34   }
35 
36   unsigned result_bit = (sign << 31) | (exponent << 23) | mantissa;
37 
38   return c10::detail::fp32_from_bits(result_bit);
39 }
40 
float2halfbits(float src)41 unsigned short float2halfbits(float src) {
42   unsigned x = c10::detail::fp32_to_bits(src);
43 
44   // NOLINTNEXTLINE(cppcoreguidelines-init-variables,cppcoreguidelines-avoid-magic-numbers)
45   unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
46   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
47   unsigned sign, exponent, mantissa;
48 
49   // Get rid of +NaN/-NaN case first.
50   if (u > 0x7f800000) {
51     return 0x7fffU;
52   }
53 
54   sign = ((x >> 16) & 0x8000);
55 
56   // Get rid of +Inf/-Inf, +0/-0.
57   if (u > 0x477fefff) {
58     return sign | 0x7c00U;
59   }
60   if (u < 0x33000001) {
61     return (sign | 0x0000);
62   }
63 
64   exponent = ((u >> 23) & 0xff);
65   mantissa = (u & 0x7fffff);
66 
67   if (exponent > 0x70) {
68     shift = 13;
69     exponent -= 0x70;
70   } else {
71     shift = 0x7e - exponent;
72     exponent = 0;
73     mantissa |= 0x800000;
74   }
75   lsb = (1 << shift);
76   lsb_s1 = (lsb >> 1);
77   lsb_m1 = (lsb - 1);
78 
79   // Round to nearest even.
80   remainder = (mantissa & lsb_m1);
81   mantissa >>= shift;
82   if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
83     ++mantissa;
84     if (!(mantissa & 0x3ff)) {
85       ++exponent;
86       mantissa = 0;
87     }
88   }
89 
90   return (sign | (exponent << 10) | mantissa);
91 }
// Cross-checks this file's reference converters against
// c10::detail::fp16_ieee_to_fp32_value / fp16_ieee_from_fp32_value on a few
// boundary bit patterns.
TEST(HalfConversionTest, TestPorableConversion) {
  const std::vector<uint16_t> inputs = {
      0, // +0
      0xfbff, // 1111 1011 1111 1111: most negative finite half
      (1 << 15 | 1), // sign bit + lowest fraction bit: tiniest negative subnormal
      0x7bff // 0111 1011 1111 1111: largest finite half
  };
  for (const auto x : inputs) {
    const auto target = c10::detail::fp16_ieee_to_fp32_value(x);
    EXPECT_EQ(halfbits2float(x), target)
        << "Test failed for uint16 to float " << x << "\n";
    // The trailing space keeps the streamed value from being glued onto
    // "uint16" in the failure output.
    EXPECT_EQ(
        float2halfbits(target), c10::detail::fp16_ieee_from_fp32_value(target))
        << "Test failed for float to uint16 " << target << "\n";
  }
}
108 
// Converts every possible half bit pattern to float through c10::Half and
// compares the result with the reference converter halfbits2float.
TEST(HalfConversion, TestNativeConversionToFloat) {
  // There are only 2**16 possible values, so test them all
  for (auto x : c10::irange(std::numeric_limits<uint16_t>::max() + 1)) {
    const auto h = c10::Half(x, c10::Half::from_bits());
    const float expected = halfbits2float(x);
    const float actual = static_cast<float>(h);
    // NaNs are not equal to each other
    if (std::isnan(expected) && std::isnan(actual)) {
      continue;
    }
    EXPECT_EQ(expected, actual) << "Conversion error using " << x;
    // Floating-point == treats -0.0f and +0.0f as equal, so the value check
    // above cannot notice a lost sign on zero (bit pattern 0x8000). Compare
    // the sign bits explicitly.
    EXPECT_EQ(std::signbit(expected), std::signbit(actual))
        << "Sign mismatch using " << x;
  }
}
121 
// Converts floats to half through c10::Half and compares the stored bits with
// the reference converter float2halfbits.
TEST(HalfConversion, TestNativeConversionToHalf) {
  auto expect_matches_reference = [](float value) {
    const auto converted = c10::Half(value);
    if (std::isnan(value)) {
      // NaN has no single required bit pattern here; it only has to stay NaN.
      EXPECT_TRUE(std::isnan(static_cast<float>(converted)));
      return;
    }
    const auto reference_bits = float2halfbits(value);
    EXPECT_EQ(converted.x, reference_bits)
        << "Conversion error using " << value;
  };

  // Every float that is exactly representable as a half (all 2**16 patterns).
  for (auto bits : c10::irange(std::numeric_limits<uint16_t>::max() + 1)) {
    expect_matches_reference(halfbits2float(bits));
  }

  // A few notable float constants, most of which lie outside the half range.
  using FloatLimits = std::numeric_limits<float>;
  expect_matches_reference(FloatLimits::max());
  expect_matches_reference(FloatLimits::min());
  expect_matches_reference(FloatLimits::epsilon());
  expect_matches_reference(FloatLimits::lowest());
}
143 
144 } // namespace
145