xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/hlo_casting_utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Casting utility functions for HLO instructions.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_
19 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_
20 
#include <type_traits>
#include <typeinfo>

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/platform/logging.h"
25 
26 namespace xla {
27 
28 template <class T>
29 using EnableIfDerivedFromHlo =
30     typename std::enable_if<std::is_base_of<HloInstruction, T>::value>::type;
31 
32 // Casts an HloInstruction pointer to one of its subclasses, dies if argument is
33 // nullptr or runtime information does not match.
34 //
35 // Similar to LLVM's cast.
36 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
Cast(const HloInstruction * instruction)37 const T* Cast(const HloInstruction* instruction) {
38   CHECK(instruction != nullptr);
39   CHECK(T::ClassOf(instruction))
40       << "Invalid HloInstruction casting. Destination type: "
41       << typeid(T).name() << ". Instruction: " << instruction->name();
42   const T* casted = static_cast<const T*>(instruction);
43 #ifndef NDEBUG
44   const T* dynamic_casted = dynamic_cast<const T*>(instruction);
45   CHECK(dynamic_casted != nullptr)
46       << "Invalid HloInstruction casting. Destination type: "
47       << typeid(T).name() << ". Instruction: " << instruction->name();
48 #endif
49   return casted;
50 }
51 
52 // Non-const overload of Cast.
53 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
Cast(HloInstruction * instruction)54 T* Cast(HloInstruction* instruction) {
55   return const_cast<T*>(
56       Cast<T>(const_cast<const HloInstruction*>(instruction)));
57 }
58 
59 // Works just like the Cast, except that it allows for a null pointer as an
60 // argument which it then propagates.
61 //
62 // Similar to LLVM's cast_or_null.
63 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
CastOrNull(const HloInstruction * instruction)64 const T* CastOrNull(const HloInstruction* instruction) {
65   return instruction != nullptr ? Cast<T>(instruction) : nullptr;
66 }
67 
68 // Non-const overload of CastOrNull.
69 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
CastOrNull(HloInstruction * instruction)70 T* CastOrNull(HloInstruction* instruction) {
71   return const_cast<T*>(
72       CastOrNull<T>(const_cast<const HloInstruction*>(instruction)));
73 }
74 
75 // Casts an HloInstruction pointer to one of its subclasses, dies if argument is
76 // nullptr, returns nullptr if runtime information does not match.
77 //
78 // Similar to LLVM's dyn_cast.
79 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
DynCast(const HloInstruction * instruction)80 const T* DynCast(const HloInstruction* instruction) {
81   CHECK(instruction != nullptr);
82   const T* casted =
83       T::ClassOf(instruction) ? static_cast<const T*>(instruction) : nullptr;
84 #ifndef NDEBUG
85   CHECK_EQ(casted, dynamic_cast<const T*>(instruction));
86 #endif
87   return casted;
88 }
89 
90 // Non-const overload of DynCast.
91 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
DynCast(HloInstruction * instruction)92 T* DynCast(HloInstruction* instruction) {
93   return const_cast<T*>(
94       DynCast<T>(const_cast<const HloInstruction*>(instruction)));
95 }
96 
97 // Works just like the DynCast, except that it allows for a null pointer as an
98 // argument which it then propagates.
99 //
100 // Similar to LLVM's dyn_cast_or_null.
101 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
DynCastOrNull(const HloInstruction * instruction)102 const T* DynCastOrNull(const HloInstruction* instruction) {
103   return instruction != nullptr ? DynCast<T>(instruction) : nullptr;
104 }
105 
106 // Non-const overload of DynCastOrNull.
107 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
DynCastOrNull(HloInstruction * instruction)108 T* DynCastOrNull(HloInstruction* instruction) {
109   return const_cast<T*>(
110       DynCastOrNull<T>(const_cast<const HloInstruction*>(instruction)));
111 }
112 
113 }  // namespace xla
114 
115 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_
116