xref: /aosp_15_r20/external/tensorflow/tensorflow/python/lib/core/float8_e4m3b11.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/python/lib/core/float8_e4m3b11.h"

#include <stdio.h>

#include <cmath>
#include <cstdint>
#include <cstring>
19 
namespace tensorflow {

// Converts an IEEE-754 binary32 value to the float8_e4m3b11 encoding:
// 1 sign bit, 4 exponent bits (bias 11), 3 mantissa bits.
//
// Conventions implemented here:
//  * 0x80 (the bit pattern that would be -0) is the only NaN; every float NaN
//    maps to it regardless of sign.
//  * There is no infinity: magnitudes >= 32, including +/-inf, saturate to
//    +/-30 (0x7f / 0xff).
//  * -0.0 and values that underflow map to +0 (0x00), because 0x80 is NaN.
//  * Excess mantissa bits are truncated (round toward zero), not rounded to
//    nearest.
uint8_t float_to_float8_e4m3b11(float v) {
  static_assert(sizeof(float) == sizeof(uint32_t), "Invalid");
  // memcpy instead of reinterpret_cast: reading a float object through a
  // uint32_t lvalue violates strict aliasing (undefined behavior).
  uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));

  const uint32_t sign = (bits & 0x80000000u) >> 24;  // 0x80 or 0x00.
  const uint32_t exponent = (bits >> 23) & 0xffu;    // Biased by 127.
  const uint32_t mantissa = bits & 0x7fffffu;

  // Unbiased exponent below -10: too small for a normal e4m3b11 value.
  if (exponent < 127 - 10) {
    // Unbiased exponent below -14: far under the smallest subnormal (2^-13).
    if (exponent < 127 - 14) {
      return 0x00;
    }
    // Subnormal code m encodes m * 2^-13. Shift the significand (including
    // its implicit leading one) right by 10 - unbiased_exponent, i.e. 21..24.
    const uint32_t shift = 10 + (127 - exponent);
    const uint32_t shifted_mantissa = (0x800000u | mantissa) >> shift;
    if (shifted_mantissa == 0) {
      return 0x00;  // Never emit 0x80 (NaN) for a tiny negative value.
    }
    return static_cast<uint8_t>(sign | shifted_mantissa);
  }
  // Unbiased exponent above 4: overflow, infinity, or NaN.
  if (exponent > 127 + 4) {
    if (exponent == 255 && mantissa != 0) {
      return 0x80;  // nan.
    }
    return static_cast<uint8_t>(0x7f | sign);  // Saturate to +/-30.
  }
  // Normal range: rebias the exponent from 127 to 11. The resulting exponent
  // field is in [1, 15], so the encoding can never equal the NaN pattern 0x80.
  const uint32_t rebiased = exponent - (127 - 11);
  return static_cast<uint8_t>(sign | (rebiased << 3) | (mantissa >> 20));
}

// Returns the number of leading zero bits in `x`, or 32 when `x` is 0.
static uint32_t clz_uint32(uint32_t x) {
  if (x == 0) {
    // __builtin_clz(0) is undefined; match the portable fallback's result.
    return 32;
  }
#ifdef __GNUC__
  return static_cast<uint32_t>(__builtin_clz(x));
#else
  uint32_t out = 32;
  while (x != 0) {
    x = x >> 1;
    out -= 1;
  }
  return out;
#endif
}

// Converts a float8_e4m3b11 encoding back to an IEEE-754 binary32 value.
// 0x80 decodes to NaN; every other encoding is exactly representable.
float float8_e4m3b11_to_float(uint8_t v) {
  if (v == 0x80) {
    return NAN;
  }
  if (v == 0) {
    return 0;
  }
  // Widen to uint32_t before shifting, so the shift is done on an unsigned
  // type rather than on a promoted int whose sign bit it would set.
  const uint32_t sign = static_cast<uint32_t>(v & 0x80) << 24;
  const uint32_t exponent_field = (v & 0x78u) >> 3;
  const uint32_t mantissa_field = v & 0x7u;

  uint32_t bits;
  if (exponent_field == 0) {
    // Subnormal: the value is mantissa_field * 2^-13. mantissa_field is
    // nonzero here because 0x00 and 0x80 were handled above. Normalize by
    // locating the leading one at bit position `lead` (0..2).
    const uint32_t lead = 31 - clz_uint32(mantissa_field);
    // Move the leading one to bit 23 (the implicit bit) and drop it.
    const uint32_t mantissa = (mantissa_field << (23 - lead)) & 0x7fffffu;
    // The value is 1.frac * 2^(lead - 13); rebias that exponent to 127.
    const uint32_t exponent = (127 - 13) + lead;
    bits = sign | (exponent << 23) | mantissa;
  } else {
    // Normal: rebias the exponent from 11 to 127 and left-align the mantissa.
    const uint32_t exponent = exponent_field + (127 - 11);
    bits = sign | (exponent << 23) | (mantissa_field << 20);
  }

  // memcpy instead of reinterpret_cast to avoid a strict-aliasing violation.
  float result;
  std::memcpy(&result, &bits, sizeof(result));
  return result;
}

}  // namespace tensorflow
88