xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
18 
19 #include <stdint.h>
20 
21 #include <string>
22 #include <vector>
23 
24 #include "absl/strings/string_view.h"
25 #include "absl/types/span.h"
26 #include "llvm/IR/BasicBlock.h"
27 #include "llvm/IR/GlobalVariable.h"
28 #include "llvm/IR/IRBuilder.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/Module.h"
31 #include "llvm/IR/Value.h"
32 #include "llvm/Support/raw_ostream.h"
33 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
34 #include "tensorflow/compiler/xla/literal.h"
35 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
36 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
37 #include "tensorflow/compiler/xla/types.h"
38 #include "tensorflow/compiler/xla/xla_data.pb.h"
39 
// Forward declarations of LLVM classes, so they can be named in this header
// without pulling in their full definitions.
namespace llvm {
class FastMathFlags;
class TargetOptions;
}  // namespace llvm
44 
45 namespace xla {
46 namespace llvm_ir {
47 
48 // Dump the given LLVM entity to a string. This works for Types and Values.
49 template <typename T>
DumpToString(const T & entity)50 std::string DumpToString(const T& entity) {
51   std::string buffer_string;
52   llvm::raw_string_ostream ostream(buffer_string);
53   entity.print(ostream);
54   ostream.flush();
55   return buffer_string;
56 }
57 
// Same as above, except that const T& does not work well with MLIR because the
59 // print methods are not const.
60 template <typename T>
DumpToString(T & entity)61 std::string DumpToString(T& entity) {
62   std::string buffer_string;
63   llvm::raw_string_ostream ostream(buffer_string);
64   entity.print(ostream);
65   ostream.flush();
66   return buffer_string;
67 }
68 
69 // Dump the given LLVM module to a string. This requires a function distinct
70 // from DumpToString because the signatures of the print() methods for Values
71 // and Modules are slightly different.
72 std::string DumpModuleToString(const llvm::Module& module);
73 
74 // Constructs a human-friendly name from the given inputs.  The result is
75 // suitable for use as an llvm::Value's name.
76 //
77 // This is equivalent to
78 //
79 //   - changing the HloInstruction* to its name() (if we called that overload),
80 //   - joining all of the nonempty inputs by '.', and then
81 //   - removing all '%'s.
82 //
83 std::string IrName(absl::string_view a);
84 std::string IrName(absl::string_view a, absl::string_view b);
85 std::string IrName(const HloInstruction* a, absl::string_view b = "");
86 
87 // Removes special characters from a function name.
88 //
89 // Note that this can cause different inputs to map to the same output, so after
90 // sanitizing a function name, you must run it through a uniquer.
91 std::string SanitizeFunctionName(std::string function_name);
92 
93 // Emits a call to the specified intrinsic with the given operands. Overloaded
94 // intrinsics (for example, "minnum") must include a type in overloaded_types
95 // for each overloaded type. Typically, overloaded intrinsics have only a single
96 // overloaded type.
97 llvm::CallInst* EmitCallToIntrinsic(
98     llvm::Intrinsic::ID intrinsic_id, absl::Span<llvm::Value* const> operands,
99     absl::Span<llvm::Type* const> overloaded_types, llvm::IRBuilder<>* b,
100     absl::string_view name = "");
101 
// Emit float max. Emits the maxnum intrinsic if fast math is disabled, or
// fcmp+select otherwise.
104 llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value,
105                           llvm::IRBuilder<>* b, bool enable_fast_min_max,
106                           absl::string_view name = "");
107 
// Emit float min. Emits the minnum intrinsic if fast math is disabled, or
// fcmp+select otherwise.
110 llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value,
111                           llvm::IRBuilder<>* b, bool enable_fast_min_max,
112                           absl::string_view name = "");
113 
114 // Convenience methods for emitting a GEP instruction that indexes into a buffer
115 // (1-dimensional array), equivalent to array[index]. The element type of the
116 // array must be explicitly passed in.  The int64_t index overload
// wraps the index in an i64 llvm::Value.
118 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Type* element_type,
119                                    llvm::Value* index, llvm::IRBuilder<>* b);
120 llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Type* element_type,
121                                    int64_t index, llvm::IRBuilder<>* b);
122 
123 // Returns the LLVM type which represents the given XLA primitive type.
124 llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
125                                   llvm::Module* module);
126 
127 // Returns the type size in bits. If "type" is a struct, it must be packed.
128 int GetSizeInBits(llvm::Type* type);
129 
130 // Returns the LLVM type which represents the given XLA shape. For example,
131 // if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]].
132 llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module);
133 
134 // Returns a value that represents a pointer to a global string constant that
135 // encodes the shape as a serialized protobuf.
136 StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
137                                                          int32_t* shape_size,
138                                                          llvm::IRBuilder<>* b);
139 
140 // Converts a given literal to an IR Constant. Literals have known constant
141 // values at IR emission time.
142 llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
143                                            llvm::Module* module);
144 
145 // Allocates a tile of shared memory.
146 llvm::GlobalVariable* AllocateSharedMemoryTile(llvm::Module* module,
147                                                llvm::Type* tile_type,
148                                                absl::string_view name);
149 
// Inserts an alloca of the requested type at the entry point of the
151 // function that the builder is currently building. The insert point
152 // of the builder is set to the same place after calling this function
153 // as before.
154 //
155 // This can be useful to avoid e.g. executing an alloca every time
156 // through a loop.
157 llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
158                                             absl::string_view name,
159                                             llvm::IRBuilder<>* b,
160                                             int alignment = 0);
161 
162 // As EmitAllocaAtFunctionEntry, but allocates element_count entries
163 // instead of a single element.
164 llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(llvm::Type* type,
165                                                      llvm::Value* element_count,
166                                                      absl::string_view name,
167                                                      llvm::IRBuilder<>* b,
168                                                      int alignment = 0);
169 
170 // Creates a basic block with the same context and function as for the
171 // builder. Inserts at the end of the function if insert_before is
172 // null.
173 llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
174                                    absl::string_view name,
175                                    llvm::IRBuilder<>* b);
176 
177 // Struct with data on a conditional branch in a diamond shape created
178 // via EmitIfThenElse.
179 struct LlvmIfData {
180   // The block that has the conditional branch.
181   llvm::BasicBlock* if_block;
182 
183   // The block that is executed if the condition is true.
184   llvm::BasicBlock* true_block;
185 
186   // The block that is executed if the condition is false.
187   llvm::BasicBlock* false_block;
188 
189   // The block that follows after both the true_block and the
190   // false_block.
191   llvm::BasicBlock* after_block;
192 };
193 
194 // Inserts a diamond-shaped if-then-else construct at the current
195 // insertion point of the builder. This involves splitting the current
196 // block into two blocks, at the insertion point, and introducing a
197 // true-block and a false-block that connect the two split pieces. The
198 // true-block is executed if the condition parameter evaluates to true
199 // and otherwise the false-block is executed. If `emit_else` is false,
200 // it jumps to the after-block rather than the false-block if the
201 // condition is false, and the returned `false_block` is null.
202 //
203 // Currently the insertion point of the builder must be a well-formed
204 // block with a terminator. If you need to use this for a
205 // non-terminated block, just make the function able to do that too.
206 LlvmIfData EmitIfThenElse(llvm::Value* condition, absl::string_view name,
207                           llvm::IRBuilder<>* b, bool emit_else = true);
208 
209 // Emits a compare operation between "lhs" and "rhs" with the given predicate,
210 // and then converts the result to i8 so that it is addressable.
211 llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
212                             llvm::Value* lhs, llvm::Value* rhs,
213                             llvm::IRBuilder<>* b, absl::string_view name = "");
214 
215 // Emits a call that logs the given value with the given tag as a prefix.
216 // The provided tag and value are passed to a runtime logging call that is
217 // embedded in this translation unit when the emitted code is executed.
218 //
219 // This can be very useful for debugging generated programs in short order when
220 // developing new generated routines.
221 //
222 // Precondition: value must be an int64_t.
223 // Precondition: tag must be a stable pointer for the lifetime of the generated
224 // program (the constant pointer is burned in to the program).
225 void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* b);
226 
227 // Adds alignment metadata to a load instruction using the given alignment.
228 // The alignment refers to the result of the load, not the load itself.
229 void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment);
230 
// Adds dereferenceable metadata to a load instruction using the given
// number of dereferenceable bytes.
233 // Dereferenceable refers to the result of the load, not the load itself.
234 void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
235                                        uint64_t dereferenceable_bytes);
236 
237 // Tells LLVM `inst >= lower && inst < upper`. Returns `inst` for convenience.
238 llvm::Instruction* AddRangeMetadata(int32_t lower, int32_t upper,
239                                     llvm::Instruction* inst);
240 
241 void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
242 
243 void SetToLastInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
244 
245 // Create a bitwise rotation of `rotand` by `rotor`.
246 llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
247                        llvm::IRBuilder<>* builder);
248 
249 // Returns the number of bytes within the shape.
250 int64_t ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout);
251 
252 // Gets an llvm::FastMathFlags that reflects the settings in the given
253 // module config.
254 llvm::FastMathFlags GetCpuFastMathFlags(const HloModuleConfig& module_config);
255 
256 // Computes a conservative union of the metadata in "a" and "b".  For
257 // aliasing-related metadata, this means the result can be applied to
258 // instructions whose aliasing relationship can be described either by "a" *or*
259 // by "b".
260 std::map<int, llvm::MDNode*> MergeMetadata(
261     llvm::LLVMContext* context, const std::map<int, llvm::MDNode*>& a,
262     const std::map<int, llvm::MDNode*>& b);
263 
264 // Dumps out `llvm_module` to the path specified in DebugOptions, if dumping is
265 // enabled for the given HLO module.
266 //
267 // A sanitized version of `hlo_module_name` is incorporated into the file name.
268 // If `optimized` is true then a suffix of "-with-opt.ll" is used, else a suffix
269 // of "-no-opt.ll" is used.
270 void DumpIrIfEnabled(const HloModule& hlo_module,
271                      const llvm::Module& llvm_module, bool optimized,
272                      absl::string_view filename_suffix = "");
273 
274 llvm::Function* CreateCpuFunction(llvm::FunctionType* function_type,
275                                   llvm::GlobalValue::LinkageTypes linkage,
276                                   const HloModuleConfig& module_config,
277                                   absl::string_view name, llvm::Module* module);
278 
279 // Zero-extends two 32-bit values to 64 bits, multiplies them, and returns the
280 // result as a pair of (low 32 bits, high 32 bits).
281 std::pair<llvm::Value*, llvm::Value*> UMulLowHigh32(llvm::IRBuilder<>* b,
282                                                     llvm::Value* src0,
283                                                     llvm::Value* src1);
284 // Splits the 64-bit integer value into its high and low 32 bits.
285 std::pair<llvm::Value*, llvm::Value*> SplitInt64ToInt32s(
286     llvm::IRBuilder<>* b, llvm::Value* value_64bits);
287 
288 // Checks whether a global variable is already created to represent the state
289 // of a random number generator. If not, creates such a variable. Returns the
290 // global variable.
291 llvm::GlobalVariable* GetOrCreateVariableRngState(llvm::Module* module,
292                                                   llvm::IRBuilder<>* b);
293 
// Adds a delta value to the global state variable and returns the old value of
295 // the variable.
296 llvm::Value* RngGetAndUpdateState(uint64_t delta, llvm::Module* module,
297                                   llvm::IRBuilder<>* b);
298 
299 // Gets the LLVM address space that should be used for global variables (e.g.
300 // XLA's rng state).
301 unsigned GetGlobalMemoryAddressSpace();
302 
303 // Emits a block which does "return void". Leaves the insert point as is.
304 llvm::BasicBlock* EmitReturnBlock(llvm::IRBuilder<>* b);
305 
306 // Emits `if (condition) return`. Assumes that the current function returns
307 // void.
308 //
309 // Can either use a supplied `return_block`, or generate a new one.
310 void EmitEarlyReturn(llvm::Value* condition, llvm::IRBuilder<>* b,
311                      llvm::BasicBlock* return_block = nullptr);
312 
313 }  // namespace llvm_ir
314 }  // namespace xla
315 
316 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
317