xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/bit_cast.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // BitCast is an extension of std::bit_cast/absl::bit_cast. Whereas those
17 // functions require trivially copyable source and destination types, the
18 // present function template may be specialized for additional types that
19 // do not satisfy that triviality property, but that have alternative ways
20 // of accessing their underlying representation.
21 //
22 // Concretely, we provide specializations for the "custom floating point types"
23 // Eigen::half and tensorflow::bfloat16. Those types are effectively stored as
24 // a sequence of bits, but the classes are not trivially copyable.
25 
26 #ifndef TENSORFLOW_COMPILER_XLA_BIT_CAST_H_
27 #define TENSORFLOW_COMPILER_XLA_BIT_CAST_H_
28 
29 #include "absl/base/casts.h"
30 #include "third_party/eigen3/Eigen/Core"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/core/platform/bfloat16.h"
33 
34 namespace xla {
35 
36 template <typename T, typename U>
BitCast(U src)37 T BitCast(U src) {
38   static_assert(sizeof(T) == sizeof(U), "sizes don't match");
39   // We would like to check std::is_trivially_copyable here, but there's no
40   // reliable implementation of that available to us.
41   return absl::bit_cast<T>(src);
42 }
43 
44 template <>
45 inline tensorflow::bfloat16 BitCast<tensorflow::bfloat16, uint16_t>(
46     uint16_t src) {
47   return Eigen::numext::bit_cast<tensorflow::bfloat16>(src);
48 }
49 
50 template <>
51 inline uint16_t BitCast<uint16_t, tensorflow::bfloat16>(
52     tensorflow::bfloat16 src) {
53   return Eigen::numext::bit_cast<uint16_t>(src);
54 }
55 
56 template <>
57 inline Eigen::half BitCast<Eigen::half, uint16_t>(uint16_t src) {
58   return Eigen::numext::bit_cast<Eigen::half>(src);
59 }
60 
61 template <>
62 inline uint16_t BitCast<uint16_t, Eigen::half>(Eigen::half src) {
63   return Eigen::numext::bit_cast<uint16_t>(src);
64 }
65 
66 }  // namespace xla
67 
68 #endif  // TENSORFLOW_COMPILER_XLA_BIT_CAST_H_
69