// Source: /aosp_15_r20/external/armnn/src/armnnUtils/BFloat16.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <stdint.h>

#include <cmath>
#include <cstring>
#include <ios>
#include <ostream>

namespace armnn
{
/// 16-bit "brain floating point" value, stored as its raw bit pattern.
///
/// The bit layout is the upper half of an IEEE-754 binary32 float:
/// 1 sign bit, 8 exponent bits and 7 mantissa bits.
class BFloat16
{
public:
    /// Constructs positive zero (bit pattern 0x0000).
    BFloat16()
    : m_Value(0)
    {}

    BFloat16(const BFloat16& v) = default;

    /// Constructs directly from a raw bit pattern; no numeric conversion is performed.
    explicit BFloat16(uint16_t v)
    : m_Value(v)
    {}

    /// Constructs from a float, rounding to nearest with ties to even (see Float32ToBFloat16).
    explicit BFloat16(float v)
    : m_Value(Float32ToBFloat16(v).Val())
    {}

    /// Widens to float. This is exact: the bits are zero-extended into the low mantissa.
    operator float() const
    {
        return ToFloat32();
    }

    BFloat16& operator=(const BFloat16& other) = default;

    /// Assigns from a float, rounding to nearest with ties to even.
    BFloat16& operator=(float v)
    {
        m_Value = Float32ToBFloat16(v).Val();
        return *this;
    }

    /// Bitwise equality: compares raw bit patterns. Two NaNs with identical bits
    /// compare equal, while +0 and -0 compare unequal.
    bool operator==(const BFloat16& r) const
    {
        return m_Value == r.Val();
    }

    /// Bitwise inequality, defined as the exact negation of operator==.
    /// Without this overload `a != b` would fall back to comparing the implicitly
    /// converted floats, contradicting operator== for NaN and signed zero.
    bool operator!=(const BFloat16& r) const
    {
        return !(*this == r);
    }

    /// Converts a float to BFloat16 using round-to-nearest, ties-to-even.
    ///
    /// Any NaN input maps to the canonical quiet NaN returned by Nan() (sign and
    /// payload are not preserved). Values that round above Max() in magnitude carry
    /// into the exponent and become infinity with the input's sign.
    static BFloat16 Float32ToBFloat16(const float v)
    {
        if (std::isnan(v))
        {
            return Nan();
        }

        // BFloat16 is the upper 16 bits of the float:
        //   Float32:  S EEEEEEEE MMMMMMLRxxxxxxxxxxxxxxx
        //   BFloat16: S EEEEEEEE MMMMMML
        // L: least significant bit kept in the BFloat16 mantissa
        // R: rounding bit (most significant of the 16 discarded bits)
        // x: remaining discarded bits
        //   R = 0                     -> round down (truncate)
        //   R = 1, some x != 0        -> round up (more than halfway)
        //   R = 1, all x == 0 (a tie) -> round up only if L = 1, leaving an even result
        uint32_t bits;
        static_assert(sizeof bits == sizeof v, "float must be 32 bits wide");
        // memcpy rather than reinterpret_cast: reading a float through a uint32_t*
        // breaks the strict-aliasing rule and is undefined behaviour.
        std::memcpy(&bits, &v, sizeof bits);

        uint16_t u16 = static_cast<uint16_t>(bits >> 16u);
        const uint16_t lsb = static_cast<uint16_t>(u16 & 0x0001u);
        // The 16 bits that truncation would discard.
        const uint16_t error = static_cast<uint16_t>(bits & 0x0000FFFFu);
        if (error > 0x8000u || (error == 0x8000u && lsb == 1u))
        {
            // Cannot wrap: NaNs were handled above, and the largest finite
            // patterns (0x7F7F / 0xFF7F) round up to +/-Inf (0x7F80 / 0xFF80).
            ++u16;
        }
        return BFloat16(u16);
    }

    /// Returns the exact float value of this BFloat16.
    float ToFloat32() const
    {
        // Widen before shifting. Otherwise m_Value is promoted to signed int,
        // and bit patterns >= 0x8000 shifted left by 16 do not fit in a positive int.
        const uint32_t u32 = static_cast<uint32_t>(m_Value) << 16u;
        float f32;
        static_assert(sizeof u32 == sizeof f32, "");
        std::memcpy(&f32, &u32, sizeof u32);
        return f32;
    }

    /// Returns the raw 16-bit pattern.
    uint16_t Val() const
    {
        return m_Value;
    }

    /// Largest finite value (bit pattern 0x7F7F).
    static BFloat16 Max()
    {
        uint16_t max = 0x7F7F;
        return BFloat16(max);
    }

    /// Positive quiet NaN (bit pattern 0x7FC0).
    static BFloat16 Nan()
    {
        uint16_t nan = 0x7FC0;
        return BFloat16(nan);
    }

    /// Positive infinity (bit pattern 0x7F80).
    static BFloat16 Inf()
    {
        uint16_t infVal = 0x7F80;
        return BFloat16(infVal);
    }

private:
    uint16_t m_Value;
};

/// Writes the value as "<float>(0x<raw bits>)", e.g. "1(0x3f80)" for 1.0 on a
/// default-formatted stream. The stream's format flags are restored afterwards,
/// so the hexadecimal base does not leak into subsequent output.
inline std::ostream& operator<<(std::ostream& os, const BFloat16& b)
{
    const std::ios_base::fmtflags savedFlags = os.flags();
    os << b.ToFloat32() << "(0x" << std::hex << b.Val() << ")";
    os.flags(savedFlags);
    return os;
}

} //namespace armnn