/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"

#include <stddef.h>
#include <string.h>

#include <functional>
#include <map>
#include <memory>
#include <stack>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc"
// IWYU pragma: no_include "llvm/Config/Targets.def.inc"

#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Mangler.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"  // from @llvm-project
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"  // from @llvm-project
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"  // from @llvm-project
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"  // from @llvm-project
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"  // from @llvm-project
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"  // from @llvm-project
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"  // from @llvm-project
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"  // from @llvm-project
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"  // from @llvm-project
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"  // from @llvm-project
#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"  // from @llvm-project
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"  // from @llvm-project
#include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Func/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/IR/Linalg.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
#include "mlir/Dialect/MemRef/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
#include "mlir/Dialect/Shape/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/Dialect/Vector/IR/VectorOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"  // from @llvm-project
#include "mlir/Target/LLVMIR/Export.h"  // from @llvm-project
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/ir/xla_framework.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h"
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/passes.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/bitcast_dtypes_expander.h"
#include "tensorflow/compiler/xla/service/broadcast_canonicalizer.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/change_op_data_type.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/comparison_expander.h"
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/conditional_to_select.h"
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_shape_verifier.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/dynamic_padder.h"
#include "tensorflow/compiler/xla/service/eigh_expander.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_command_line_options.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/logistic_expander.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/operand_upcaster.h"
#include "tensorflow/compiler/xla/service/optimization_barrier_expander.h"
#include "tensorflow/compiler/xla/service/qr_expander.h"
#include "tensorflow/compiler/xla/service/reduce_decomposer.h"
#include "tensorflow/compiler/xla/service/reduce_scatter_decomposer.h"
#include "tensorflow/compiler/xla/service/reshape_decomposer.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/result_caster.h"
#include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
#include "tensorflow/compiler/xla/service/rng_expander.h"
#include "tensorflow/compiler/xla/service/scatter_expander.h"
#include "tensorflow/compiler/xla/service/select_and_scatter_expander.h"
#include "tensorflow/compiler/xla/service/sharding_propagation.h"
#include "tensorflow/compiler/xla/service/sharding_remover.h"
#include "tensorflow/compiler/xla/service/slice_sinker.h"
#include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
#include "tensorflow/compiler/xla/service/sort_simplifier.h"
#include "tensorflow/compiler/xla/service/spmd/stateful_rng_spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/topk_rewriter.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tree_reduction_rewriter.h"
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace {

// We need to explicitly load all the dialects involved in emitting the IR.
// This is only needed because of how MLIR is bolted into XLA and does not
// make use of the MLIR infrastructure (like using a proper pass pipeline).
// Hopefully this will all go away at some point in favor of a better
// integration.
void LoadMLIRDialects(mlir::MLIRContext& context) {
  context.loadDialect<mlir::arith::ArithmeticDialect,
                      mlir::linalg::LinalgDialect, mlir::scf::SCFDialect,
                      mlir::vector::VectorDialect, mlir::func::FuncDialect,
                      mlir::AffineDialect, mlir::tensor::TensorDialect,
                      mlir::xla_framework::XLAFrameworkDialect>();
  mlir::registerLLVMDialectTranslation(context);
}

}  // namespace

namespace xla {

namespace {

bool UseMlirHloLowering(bool use_mlir, HloModule* module) {
  // TODO(tpopp): The prototype currently does not properly handle constant
  // buffers that are handled by the runtime's buffer assignment.
  return use_mlir &&
         module->entry_computation()->root_instruction()->opcode() !=
             HloOpcode::kConstant;
}

// For each computation in the module, determines whether that computation
// calls a custom-call function, either directly or indirectly (e.g. because it
// calls another computation that does).
absl::flat_hash_map<const HloComputation*, bool>
ModuleComputationsTransitivelyContainCustomCall(const HloModule& module) {
  absl::flat_hash_map<const HloComputation*, bool> custom_call_map;
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);

  // Can never fail because we always return an OK status from the visitor.
  TF_CHECK_OK(call_graph->VisitNodes([&custom_call_map](
                                         const CallGraphNode& node) {
    const HloComputation* computation = node.computation();

    for (const HloInstruction* instruction : computation->instructions()) {
      // The computation contains a custom-call instruction directly.
      if (DynCast<HloCustomCallInstruction>(instruction)) {
        custom_call_map[computation] = true;
        return OkStatus();
      }
      // The computation calls something that contains a custom-call
      // instruction (directly or indirectly). This lookup relies on the call
      // graph traversing callees before callers, so that the map is always
      // populated for all callees at this point.
      for (const HloComputation* callee : instruction->called_computations()) {
        bool callee_contains_custom_call = FindOrDie(custom_call_map, callee);
        if (callee_contains_custom_call) {
          custom_call_map[computation] = true;
          return OkStatus();
        }
      }
    }

    custom_call_map[computation] = false;
    return OkStatus();
  }));

  return custom_call_map;
}
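
// Illustrative example (not part of the pipeline): if entry computation E
// calls computation A, and A contains a kCustomCall instruction, the map
// returned above contains {A: true, E: true}; a computation that reaches no
// custom-call maps to false. A minimal sketch of how a caller might consume
// the result, assuming a populated HloModule `module`:
//
//   auto custom_call_map =
//       ModuleComputationsTransitivelyContainCustomCall(*module);
//   bool entry_has_custom_call =
//       custom_call_map.at(module->entry_computation());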

}  // namespace

namespace cpu {
using BufferInfo = cpu_function_runtime::BufferInfo;

CpuAotCompilationOptions::CpuAotCompilationOptions(
    std::string triple, std::string cpu_name, std::string features,
    std::string entry_point_name, RelocationModel relocation_model)
    : triple_(std::move(triple)),
      cpu_name_(std::move(cpu_name)),
      features_(std::move(features)),
      entry_point_name_(std::move(entry_point_name)),
      relocation_model_(relocation_model) {}

CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;

se::Platform::Id CpuAotCompilationOptions::PlatformId() const {
  return se::host::kHostPlatformId;
}

CpuAotCompilationResult::CpuAotCompilationResult(
    ObjectFileData object_file_data, std::vector<BufferInfo> buffer_infos,
    int64_t result_buffer_index,
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
    : object_file_data_(std::move(object_file_data)),
      buffer_infos_(std::move(buffer_infos)),
      result_buffer_index_(result_buffer_index),
      hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {}

CpuAotCompilationResult::~CpuAotCompilationResult() = default;

CpuCompiler::CpuCompiler() {
  // Initialize LLVM the first time the CpuCompiler is initialized.
  static bool llvm_initialized = []() {
    InitializeLLVMTarget();
    return true;
  }();
  (void)llvm_initialized;
}

StatusOr<std::vector<std::unique_ptr<Executable>>> CpuCompiler::Compile(
    std::unique_ptr<HloModuleGroup> module_group,
    std::vector<std::vector<se::StreamExecutor*>> stream_execs,
    const CompileOptions& options) {
  for (const std::vector<se::StreamExecutor*>& se_vector : stream_execs) {
    if (se_vector.size() != 1) {
      return Unimplemented(
          "Model partitioning not implemented for the CPU compiler");
    }
  }
  return LLVMCompiler::Compile(std::move(module_group), stream_execs, options);
}

/* static */ void CpuCompiler::InitializeLLVMTarget() {
  // Initialize LLVM's MC layer for the native target.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
}

namespace {

// LLVM makes certain options configurable only through its command-line
// options; it provides the ParseCommandLineOptions function that lets us set
// flags at runtime. However, since these flags are global, we want to avoid
// invoking the LLVM compilation pipeline multiple times with different sets
// of flags. Therefore, we only pass command-line flags to LLVM once, before
// the first module is compiled.
absl::once_flag llvm_command_line_options_initialized;

// This visitor records which HLO instructions should have profiling
// information recorded.
class CollectProfileCandidates : public DfsHloVisitorWithDefault {
 public:
  static StatusOr<absl::flat_hash_map<const HloInstruction*, int64_t>>
  GetCandidatesForComputation(
      const HloComputation& computation,
      const absl::flat_hash_map<const HloInstruction*, int64_t>&
          assigned_indices) {
    absl::flat_hash_map<const HloInstruction*, int64_t> hlo_to_profile_idx;
    CollectProfileCandidates profile_candidates_for_computation(
        &hlo_to_profile_idx, assigned_indices);
    TF_RETURN_IF_ERROR(computation.Accept(&profile_candidates_for_computation));
    return hlo_to_profile_idx;
  }

 private:
  CollectProfileCandidates(
      absl::flat_hash_map<const HloInstruction*, int64_t>* hlo_to_profile_idx,
      const absl::flat_hash_map<const HloInstruction*, int64_t>&
          assigned_indices)
      : hlo_to_profile_idx_(hlo_to_profile_idx),
        assigned_indices_(assigned_indices) {}

  Status DefaultAction(HloInstruction* hlo_instruction) override {
    hlo_to_profile_idx_->insert(
        {hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)});
    return OkStatus();
  }

  Status HandleCall(HloInstruction* call) override {
    TF_RETURN_IF_ERROR(DefaultAction(call));
    CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_,
                                                 assigned_indices_);
    TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
    return OkStatus();
  }

  // Recurse into "conditional" so we can profile inside of it.
  Status HandleConditional(HloInstruction* conditional) override {
    TF_RETURN_IF_ERROR(DefaultAction(conditional));

    CollectProfileCandidates candidates_for_true(hlo_to_profile_idx_,
                                                 assigned_indices_);
    TF_RETURN_IF_ERROR(
        conditional->true_computation()->Accept(&candidates_for_true));

    CollectProfileCandidates candidates_for_false(hlo_to_profile_idx_,
                                                  assigned_indices_);
    TF_RETURN_IF_ERROR(
        conditional->false_computation()->Accept(&candidates_for_false));

    return OkStatus();
  }

  // Skip constants, there is nothing to profile.
  Status HandleConstant(HloInstruction*) override { return OkStatus(); }
  // Skip parameters, they are a simple load.
  Status HandleParameter(HloInstruction*) override { return OkStatus(); }
  // It is important to recurse for "while" or else we risk overly coarse
  // profiling information.
  Status HandleWhile(HloInstruction* xla_while) override {
    TF_RETURN_IF_ERROR(DefaultAction(xla_while));

    CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_,
                                                      assigned_indices_);
    TF_RETURN_IF_ERROR(
        xla_while->while_condition()->Accept(&candidates_for_condition));

    CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_,
                                                 assigned_indices_);
    TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body));

    return OkStatus();
  }

  absl::flat_hash_map<const HloInstruction*, int64_t>* hlo_to_profile_idx_;
  const absl::flat_hash_map<const HloInstruction*, int64_t>& assigned_indices_;
};
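
// Illustrative usage sketch, mirroring how CreateHloProfilingArtifacts below
// drives this visitor (`index_map` here is a hypothetical HloProfileIndexMap
// built for the module):
//
//   TF_ASSIGN_OR_RETURN(
//       auto hlo_to_profile_idx,
//       CollectProfileCandidates::GetCandidatesForComputation(
//           *module.entry_computation(),
//           index_map->instruction_to_profile_idx()));
//
// Constants and parameters are intentionally absent from the result, while
// instructions inside called, while, and conditional computations are
// included via the recursive handlers above.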

// Adds the HloVerifier for CPU to the given pipeline.
void AddHloVerifier(HloPassPipeline* pipeline, HloVerifierOpts&& opts = {},
                    bool debug_only = false) {
  std::unique_ptr<TargetVerifierMetadata> verifier_metadata =
      std::make_unique<CpuVerifierMetadata>(std::move(opts));
  if (debug_only) {
    pipeline->AddInvariantCheckerDebug<HloVerifier>(
        std::move(verifier_metadata), "hlo verifier (debug)");
  } else {
    pipeline->AddInvariantChecker<HloVerifier>(std::move(verifier_metadata),
                                               "hlo verifier");
  }
}

}  // namespace

Status CpuCompiler::RunHloPassesThroughLayoutAssn(
    HloModule* module, bool /*is_aot_compile*/,
    LLVMTargetMachineFeatures* target_machine_features, bool is_mlir_compile) {
  if (module->config().use_spmd_partitioning()) {
    HloPassPipeline spmd_pipeline("spmd-partitioner");
    const int64_t num_partitions = module->config().num_partitions();
    if (num_partitions > 1) {
      // Run some IR cleanup passes before running the SPMD partitioning
      // passes.
      AddHloVerifier(&spmd_pipeline);
      spmd_pipeline.AddPass<CallInliner>();
      spmd_pipeline.AddPass<ZeroSizedHloElimination>();
      spmd_pipeline.AddPass<ConditionalCanonicalizer>();

      spmd_pipeline.AddPass<ShardingPropagation>(
          /*is_spmd=*/true, /*propagate_metadata=*/false,
          module->config().allow_spmd_sharding_propagation_to_output());
      spmd_pipeline.AddPass<spmd::StatefulRngSpmdPartitioner>(
          num_partitions, module->config().replica_count());
    } else {
      // Remove redundant sharding ops when partition_count == 1.
      spmd_pipeline.AddPass<ShardingRemover>();
      spmd_pipeline.AddPass<HloDCE>();
    }
    TF_RETURN_IF_ERROR(spmd_pipeline.Run(module).status());
  }

  HloPassPipeline pipeline("HLO passes through layout assignment");
  AddHloVerifier(&pipeline);

  pipeline.AddPass<OperandUpcaster>();
  pipeline.AddPass<ResultCaster>();

  // Expand random number generation.
  pipeline.AddPass<RngExpander>();
  pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);

  // Remove zero-sized HLO from the input so that other passes don't have to
  // handle it.
  pipeline.AddPass<ZeroSizedHloElimination>();

  pipeline.AddPass<DynamicIndexSplitter>();

  pipeline.AddPass<ConditionalToSelect>();
  pipeline.AddPass<MapInliner>();

  pipeline.AddPass<ComparisonExpander>();
  pipeline.AddPass<CholeskyExpander>();
  pipeline.AddPass<QrExpander>();
  pipeline.AddPass<EighExpander>();
  pipeline.AddPass<TriangularSolveExpander>();
  pipeline.AddPass<AllGatherDecomposer>();
  pipeline.AddPass<AllToAllDecomposer>();
  pipeline.AddPass<ReduceScatterDecomposer>();

  // Inline computations with a single call site.
  pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
  pipeline.AddPass<BatchDotSimplification>();
  pipeline.AddPass<DotDecomposer>();
  // Convert BF16 operations to F32 operations so that the CPU backend can
  // support BF16 operations without directly implementing a BF16 lowering for
  // most ops.
  BFloat16Support bf16;
  pipeline.AddPass<BFloat16Normalization>(&bf16);
  // After canonicalization, there may be more batch dots that can be
  // simplified.
  pipeline.AddPass<BatchDotSimplification>();
  auto cost_model = [](HloInstruction* conv) {
    // We need a cost model for CPUs. Currently, do nothing.
    return false;
  };
  pipeline.AddPass<ConvolutionGroupConverter>(
      /*should_expand=*/[](HloInstruction* conv) { return true; }, cost_model,
      /*convert_batch_groups_only=*/true);
  auto feature_group_should_expand = [](HloInstruction* conv) {
    switch (conv->shape().element_type()) {
      case F16:
      case F32:
        return false;
      default:
        return true;
    }
  };
  pipeline.AddPass<ConvolutionGroupConverter>(
      feature_group_should_expand, cost_model,
      /*convert_batch_groups_only=*/false);
  pipeline.AddPass<BatchNormExpander>(
      /*rewrite_training_op=*/true,
      /*rewrite_inference_op=*/true,
      /*rewrite_grad_op=*/true);
  pipeline.AddPass<LogisticExpander>(
      /*expansion_type=*/LogisticExpansionType::kExp);
  pipeline.AddPass<ConditionalCanonicalizer>();
  pipeline.AddPass<DynamicDimensionSimplifier>();
  auto dynamic_padder_options = DynamicPadderOptions();
  dynamic_padder_options.shape_check_mode =
      DynamicDimensionInference::ShapeCheckMode::kCompileTime;
  pipeline.AddPass<DynamicPadder>(dynamic_padder_options);
  pipeline.AddPass<SelectAndScatterExpander>();
  pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
  pipeline.AddPass<ConvCanonicalization>(target_machine_features);

  // Run fp16 dots/convs in fp32 and then downcast the result to fp16.
  // Justification:
  //
  //  - This is significantly faster on our CPUs today than true fp16.
  //  - It's numerically more accurate. (Granted, this is not always
  //    desirable, thus the ability to disable this functionality.)
  //  - It matches more closely the GPU's behavior on fp16 dot/conv, where
  //    accumulation happens in f32.
  if (!module->config().debug_options().xla_cpu_strict_dot_conv_math()) {
    pipeline.AddPass<ChangeOpDataType>(
        F16, F32, [](const HloInstruction* instr) {
          return instr->opcode() == HloOpcode::kDot ||
                 instr->opcode() == HloOpcode::kConvolution;
        });
  }

  // Run the following passes to a fixed point.
  [&pipeline =
       pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification")] {
    AddHloVerifier(&pipeline, HloVerifierOpts{}, /*debug_only=*/true);

    AlgebraicSimplifierOptions options;
    options.set_enable_dot_strength_reduction(false);
    // TODO(b/209827141): XLA:CPU doesn't propagate NaN through min/max, but
    // other platforms do, so it should be changed.
    options.set_minmax_propagate_nan(false);
    pipeline.AddPass<AlgebraicSimplifier>(options);
    pipeline.AddPass<SortSimplifier>();
    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);

    // Needs to happen after algebraic simplifier.
    pipeline.AddPass<TreeReductionRewriter>();

    // BatchNormExpander can create zero-sized ops, so zero-sized HLO
    // elimination has to come after that pass.
    pipeline.AddPass<ZeroSizedHloElimination>();

    pipeline.AddPass<WhileLoopInvariantCodeMotion>();
    pipeline.AddPass<TupleSimplifier>();
    pipeline.AddPass<WhileLoopConstantSinking>();
    pipeline.AddPass<WhileLoopSimplifier>();

    // TODO(b/134075051): Re-enable after b/134075051 is fixed.
    // pipeline.AddPass<SliceSinker>();

    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<ReshapeMover>();
    pipeline.AddPass<HloConstantFolding>();
    pipeline.AddPass<ConditionalSimplifier>();
  }();
  pipeline.AddPass<BitcastDtypesExpander>();

  // XLA lowers topk to a libcall while the MLIR based pipeline does not yet
  // support libcalls. Disable this for now.
  if (!is_mlir_compile) {
    pipeline.AddPass<TopkRewriter>([](const HloSortInstruction* sort, int64_t) {
      return sort->operand(0)->shape().element_type() == F32;
    });
  }
  pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
  pipeline.AddPass<TransposeFolding>(
      [&](const HloInstruction& dot, int64_t operand) -> StatusOr<bool> {
        if (DotImplementationCanHandleTranspose(dot,
                                                *target_machine_features)) {
          return TransposeFolding::IsRowColumnTransposeDotOperand(dot, operand);
        }
        return false;
      },
      TransposeFolding::NeverFoldTranspose);
  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);

  pipeline.AddPass<OptimizationBarrierExpander>();
  pipeline.AddPass<TupleSimplifier>();

  // Layout assignment uses alias analysis, which requires the call graph to be
  // flattened.
  pipeline.AddPass<FlattenCallGraph>();
  ChannelLayoutConstraints layout_constraints;
  pipeline.AddPass<CpuLayoutAssignment>(
      module->mutable_entry_computation_layout(), target_machine_features,
      &layout_constraints);

  return pipeline.Run(module).status();
}

Status CpuCompiler::RunHloPassesAfterLayoutAssn(
    HloModule* module, bool is_aot_compile,
    LLVMTargetMachineFeatures* target_machine_features, bool is_mlir_compile) {
  {
    HloPassPipeline pipeline("hlo normalization");
    pipeline.AddPass<ReshapeDecomposer>();
    pipeline.AddPass<ReduceDecomposer>();
    pipeline.AddPass<BroadcastCanonicalizer>();
    TF_RETURN_IF_ERROR(pipeline.Run(module).status());
  }

  HloPassPipeline pipeline("HLO passes after layout assignment");

  // CopyInsertion is still needed by BufferAssignment. MLIR passes will handle
  // everything else done by XLA, but CopyInsertion is needed to interface with
  // the existing runtime.
  if (is_mlir_compile) {
    pipeline.AddPass<CopyInsertion>();
    return pipeline.Run(module).status();
  }

  // After layout assignment, use a layout-sensitive verifier.
  pipeline.AddPass<HloPassPipeline>("after layout assignment");
  AddHloVerifier(&pipeline, HloVerifierOpts{}.MakeLayoutSensitive(),
                 /*debug_only=*/true);

  pipeline.AddPass<ReshapeDecomposer>();

  // Add a fusion pass now that layout assignment is done.
  pipeline.AddPass<CpuInstructionFusion>();

  // The LayoutAssignment pass may leave behind kCopy instructions which are
  // duplicate or NOPs, so remove them with algebraic simplification and CSE.
  // Run this to a fixed point.
  [&pipeline = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
       "simplification after layout assignment")] {
    AddHloVerifier(
        &pipeline,
        HloVerifierOpts{}.MakeLayoutSensitive().WithInstructionCanChangeLayout(
            LayoutAssignment::InstructionCanChangeLayout),
        /*debug_only=*/true);
    AlgebraicSimplifierOptions options;
    options.set_is_layout_sensitive(true);
    options.set_enable_dot_strength_reduction(false);
    // TODO(b/209827141): XLA:CPU doesn't propagate NaN through min/max, but
    // other platforms do, so it should be changed.
    options.set_minmax_propagate_nan(false);
    pipeline.AddPass<AlgebraicSimplifier>(options);
    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
  }();

  // Outline ops in the entry computation into calls to subcomputations.
  const int max_parallelism =
      module->config().intra_op_parallelism_threads() > 0
          ? module->config().intra_op_parallelism_threads()
          : tensorflow::port::NumSchedulableCPUs();
  if (!is_aot_compile) {
    // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module.
    // Note this is not run for AOT because it would bring in thread pool
    // and thread synchronization dependencies which would likely increase
    // binary size (and most AOT applications are single-threaded).
    // TODO(b/29630486) Support multi-threaded AOT.
    pipeline.AddPass<ParallelTaskAssigner>(
        max_parallelism, ShapeSizeBytesFunction(), target_machine_features);
  }
  // Copy insertion should be performed immediately before IR emission to
  // avoid inserting unnecessary copies (later pass adds an instruction which
  // materializes the value) or missing a necessary copy (later pass removes
  // an instruction which materializes a value). DCE must be run immediately
  // before (and sometime after) copy insertion, to avoid dead code from
  // interfering with the rewrites.
  pipeline.AddPass<HloDCE>();
  pipeline.AddPass<CopyInsertion>();
  pipeline.AddPass<HloDCE>();
  return pipeline.Run(module).status();
}

Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
                                 llvm::TargetMachine* target_machine,
                                 bool is_mlir_compile) {
  LLVMTargetMachineFeatures target_machine_features(target_machine);
  TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(
      module, is_aot_compile, &target_machine_features, is_mlir_compile));

  return RunHloPassesAfterLayoutAssn(
      module, is_aot_compile, &target_machine_features,
      UseMlirHloLowering(is_mlir_compile, module));
}

namespace {

// Align buffers to 16-byte boundaries.
int64_t memory_alignment(LogicalBuffer::Color) {
  return cpu_function_runtime::MinAlign();
}

llvm::TargetOptions CompilerTargetOptions(
    const HloModuleConfig& module_config) {
  llvm::TargetOptions target_options;
  // Always allow FMA fusion. This increases precision instead of decreasing
  // it.
  target_options.AllowFPOpFusion = llvm::FPOpFusion::Fast;
  return target_options;
}

llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) {
  VLOG(2) << "backend_optimization_level: "
          << module_config.debug_options().xla_backend_optimization_level();
  switch (module_config.debug_options().xla_backend_optimization_level()) {
    case 1:
      return llvm::CodeGenOpt::Less;
    case 2:
      return llvm::CodeGenOpt::Default;
    case 3:
      return llvm::CodeGenOpt::Aggressive;
    default:
      return llvm::CodeGenOpt::None;
  }
}
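
// For reference, the mapping above is (with the level typically set through
// the --xla_backend_optimization_level debug flag):
//
//   level 1                              -> llvm::CodeGenOpt::Less
//   level 2                              -> llvm::CodeGenOpt::Default
//   level 3                              -> llvm::CodeGenOpt::Aggressive
//   level 0 (or anything unrecognized)   -> llvm::CodeGenOpt::None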

std::pair<LLVMCompiler::ModuleHook, LLVMCompiler::ModuleHook> GetIRModuleHooks(
    const HloModule& hlo_module,
    const LLVMCompiler::ModuleHook& user_pre_optimization_hook,
    const LLVMCompiler::ModuleHook& user_post_optimization_hook) {
  // Create the IR hooks. If applicable, each IR hook does the following:
  //
  //  * Calls the user supplied module hook.
  //  * Writes out the IR to a file in the output directory designated by
  //    --xla_dump_to
  const HloModule* hlo_module_ptr = &hlo_module;
  auto hook = [user_pre_optimization_hook, user_post_optimization_hook,
               hlo_module_ptr](bool optimized,
                               const llvm::Module& llvm_module) {
    const auto& user_hook =
        !optimized ? user_pre_optimization_hook : user_post_optimization_hook;
    if (user_hook) {
      user_hook(llvm_module);
    }
    llvm_ir::DumpIrIfEnabled(*hlo_module_ptr, llvm_module, optimized);
  };
  return {[hook](const llvm::Module& llvm_module) {
            return hook(/*optimized=*/false, llvm_module);
          },
          [hook](const llvm::Module& llvm_module) {
            return hook(/*optimized=*/true, llvm_module);
          }};
}

Status VerifyLlvmModule(const llvm::Module& llvm_module) {
  XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier");

  std::string err;
  llvm::raw_string_ostream err_stream(err);

  // verifyModule() returns true if the module is broken.
  TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
      << "Invalid LLVM IR before optimizations:\n"
      << err_stream.str()
      << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
         "Rerun with --xla_dump_to to get the IR. ";
  return OkStatus();
}

Status CreateHloProfilingArtifacts(
    const HloModule& module,
    absl::flat_hash_map<const HloInstruction*, int64_t>*
        instruction_to_profile_idx,
    absl::flat_hash_map<const HloComputation*, int64_t>*
        computation_to_profile_idx,
    std::unique_ptr<HloProfileIndexMap>* hlo_profile_index_map,
    std::unique_ptr<HloProfilePrinterData>* hlo_profile_printer_data) {
  *hlo_profile_index_map = std::make_unique<HloProfileIndexMap>(module);
  const HloComputation& entry_computation = *module.entry_computation();

  TF_ASSIGN_OR_RETURN(
      *instruction_to_profile_idx,
      CollectProfileCandidates::GetCandidatesForComputation(
          entry_computation,
          (*hlo_profile_index_map)->instruction_to_profile_idx()));

  auto shape_size_bytes = [](const Shape& shape) {
    // On the cpu, opaques are pointers.
    if (shape.IsOpaque()) {
      return static_cast<int64_t>(sizeof(void*));
    }
    return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
  };

  HloCostAnalysis cost_analysis(shape_size_bytes);
  TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis));
  *hlo_profile_printer_data = CreateHloProfilePrinterData(
      **hlo_profile_index_map, cost_analysis, entry_computation.name());
  *computation_to_profile_idx =
      (*hlo_profile_index_map)->computation_to_profile_idx();

  return OkStatus();
}

}  // namespace

StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
    std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
    const CompileOptions& /*options*/) {
  std::unique_ptr<llvm::TargetMachine> jit_target_machine =
      SimpleOrcJIT::InferTargetMachineForJIT(
          CompilerTargetOptions(module->config()),
          CodeGenOptLevel(module->config()));

  TF_RETURN_IF_ERROR(RunHloPasses(
      module.get(), /*is_aot_compile=*/false, jit_target_machine.get(),
      /*is_mlir_compile=*/
      module->config().debug_options().xla_cpu_enable_mlir_lowering()));
  return std::move(module);
}

StatusOr<std::unique_ptr<BufferAssignment>> CpuCompiler::AssignBuffers(
    const HloModule* module) {
  // Select an order for emitting the HLO instructions for each computation.
  // Using this sequence enables tighter buffer liveness analysis and reduced
  // memory usage (as compared to using DependencyHloOrdering).
  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      ScheduleModule(module, BufferSizeBytesFunction(),
                                     ComputationSchedulerToModuleScheduler(
                                         DFSMemoryScheduler)));

  // Run buffer allocation on the HLO graph.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<BufferAssignment> assignment,
      BufferAssigner::Run(module,
                          std::make_unique<SequentialHloOrdering>(schedule),
                          BufferSizeBytesFunction(), memory_alignment,
                          /*allocate_buffers_for_constants=*/true));

  return std::move(assignment);
}

namespace {

// Post-compilation callback functor for use by SimpleOrcJIT.
//
// Dumps machine code if dumping is enabled for the module.
struct OrcJITPostCompilationHook {
  // Gets an std::function that implements this hook.
  static std::function<void(const llvm::object::ObjectFile& obj_file)> Create(
      const HloModule* module) {
    // This struct is not copyable, but std::functions must be. So to create an
    // std::function out of this struct, we have to wrap it in a shared_ptr.
    auto wrapped = std::make_shared<OrcJITPostCompilationHook>(module);
    return [wrapped](const llvm::object::ObjectFile& obj_file) {
      (*wrapped)(obj_file);
    };
  }

  // Constructor can't be private because we want to call it from
  // std::make_shared, but users should call Create() instead.
  explicit OrcJITPostCompilationHook(const HloModule* module)
      : module(module) {}

 private:
  void operator()(const llvm::object::ObjectFile& obj_file) {
    if (!DumpingEnabledForHloModule(*module)) {
      return;
    }
    DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o",
                    absl::string_view(obj_file.getData().data(),
                                      obj_file.getData().size()));
  }

  const HloModule* module;
};
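
// Illustrative usage sketch: the callable returned by Create() is handed to
// SimpleOrcJIT at construction time (see CompileLegacyCpuExecutable below),
// and the JIT invokes it with each object file it produces, e.g.:
//
//   auto post_compile_hook =
//       OrcJITPostCompilationHook::Create(module.get());
//   // ...then passed as the post-compilation callback to SimpleOrcJIT.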

void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
  llvm_ir::InitializeLLVMCommandLineOptions(
      config.debug_options().xla_backend_extra_options());
}

Status LowerMLIRModule(mlir::ModuleOp mlir_module,
                       mlir::MLIRContext& mlir_context) {
  LoadMLIRDialects(mlir_context);
  mlir::PassManager pm(&mlir_context);
  // Resolve all shape constraints (e.g. broadcast constraints that can be
  // proved statically and changed to const witness) early to allow more
  // efficient broadcast operations moving.
  // Move up broadcasting operations to allow for more fusion opportunities.
  pm.addPass(mlir::createInlinerPass());
  pm.addPass(mlir::mhlo::createExpandHloTuplesPass("main"));
  // TODO(b/233771980): Remove once custom_call doesn't use tuples.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::mhlo::createFlattenTuplePass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mhlo::createLegalizeGeneralDotPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mhlo::createBroadcastPropagationPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Transform HLO operations to Linalg.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::mhlo::createLegalizeSortPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mhlo::createLegalizeControlFlowPass());
  pm.addPass(::mlir::mhlo::createLegalizeToArithmeticPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mhlo::createLegalizeHloToLinalgPass());

  // Lower index cast on tensors to tensor.generate.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createLowerIndexCastPass());

  pm.addPass(mlir::mhlo::createConvertToSignlessPass());

  // Lower shape dialect to standard to enable linalg canonicalizations (e.g.
  // use linalg inputs instead of outputs for memref.dim operations).
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createShapeSimplification());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createShapeToShapeLowering());
  pm.addPass(mlir::createConvertShapeToStandardPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createConvertShapeConstraintsPass());

  // Fuse Linalg on tensors operations.
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::memref::createResolveShapedTypeResultDimsPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createLinalgElementwiseOpFusionPass());
  pm.addPass(mlir::createReconcileUnrealizedCastsPass());
  pm.addPass(mlir::createConvertTensorToLinalgPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createLinalgInitTensorToAllocTensorPass());

  // Always run canonicalizer (which does dead code removal) before
  // bufferizing anything.
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::hlo::createOneShotBufferizePass());

  // Handle framework specific requirements for buffers and then insert
  // deallocations for temporary buffers.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createConvertLinalgToLoopsPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::bufferization::createBufferResultsToOutParamsPass());
  pm.addPass(mlir::mhlo::CreateOutlineWithXLAFrameworkPass());
  pm.addPass(mlir::createInlinerPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::bufferization::createBufferDeallocationPass());

  pm.addPass(mlir::createBufferizationToMemRefPass());

  // Specialize linalg.matmul to linalg.dot, linalg.matvec, or linalg.vecmat,
  // and immediately canonicalize to clean up branches that are not taken.
  // pm.addNestedPass<mlir::func::FuncOp>(CreateLinalgMatmulSpecializationPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Tile and vectorize linalg operations using the Linalg Codegen Strategy.
  // pm.addNestedPass<mlir::func::FuncOp>(CreateCodegenStrategyForMatMulPass());

  // TODO(tpopp): Move this to mlir::hlo::createGenericHostToLLVMPass?
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createConvertComplexToStandardPass());

  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  mlir::VectorTransferToSCFOptions vec_to_scf_options;
  vec_to_scf_options.unroll = true;
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createConvertVectorToSCFPass(vec_to_scf_options));
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::arith::createArithmeticExpandOpsPass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::memref::createExpandOpsPass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createLowerAffinePass());
  pm.addPass(mlir::mhlo::CreateLegalizeXLAFrameworkToLLVMPass());
  pm.addPass(mlir::hlo::createGenericHostToLLVMPass());
  pm.addPass(mlir::createReconcileUnrealizedCastsPass());
  if (pm.run(mlir_module).failed()) {
    mlir_module->dump();
    return tensorflow::errors::Internal(
        "Failed to compile through MLIR pipeline");
  }

  return OkStatus();
}
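
// At a glance, the pipeline assembled above proceeds in stages: MHLO cleanup
// (inlining, tuple expansion/flattening, dot legalization, broadcast
// propagation), lowering of HLO to Linalg on tensors, shape-dialect lowering,
// elementwise fusion, one-shot bufferization, lowering of Linalg to loops,
// XLA framework legalization, and finally conversion of everything to the
// LLVM dialect via createGenericHostToLLVMPass.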

StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> createMLIRModule(
    HloModule* module, mlir::MLIRContext& mlir_context,
    BufferAssignment* assignment) {
  LoadMLIRDialects(mlir_context);
  mlir::OpBuilder builder(&mlir_context);
  auto mlir_module = builder.create<mlir::ModuleOp>(builder.getUnknownLoc());
  TF_RETURN_IF_ERROR(ConvertHloToMlirHlo(mlir_module, module));

  // Add buffer mappings. The first attribute is the index of the slice, the
  // second is a boolean attribute on whether the allocation is writeable.
  llvm::SmallVector<std::pair<mlir::Attribute, mlir::Attribute>>
      operand_mapping;
  for (auto i : module->entry_computation()->parameter_instructions()) {
    auto slice = assignment->GetUniqueTopLevelSlice(i);
    operand_mapping.emplace_back(
        builder.getI32IntegerAttr(static_cast<int32_t>(slice->index())),
        builder.getBoolAttr(!slice->allocation()->is_readonly()));
  }

  auto root_instr = module->entry_computation()->root_instruction();
  auto output_allocation = assignment->GetUniqueTopLevelOutputSlice();

  // Gather mappings to each element in the tuple, if necessary.
  llvm::SmallVector<mlir::Attribute> result_inner_mapping;
  if (output_allocation->allocation()->is_tuple()) {
    for (auto i : llvm::seq<int>(0, root_instr->shape().tuple_shapes_size())) {
      result_inner_mapping.push_back(mlir::IntegerAttr::get(
          mlir::IntegerType::get(&mlir_context, 64),
          assignment->GetUniqueSlice(root_instr, {i})->index()));
    }
  }

  auto result_mapping = builder.getI32IntegerAttr(
      static_cast<int32_t>(output_allocation->index()));
  mlir_module->walk([&](mlir::func::FuncOp f) {
    if (f.getSymName() == "main") {
      for (auto& p : llvm::enumerate(operand_mapping)) {
        f.setArgAttr(p.index(), "xla_framework.input_mapping", p.value().first);
        // Mark argument as (non-)writeable for bufferization. This ensures
        // that entry parameters are not overwritten.
        f.setArgAttr(p.index(), "bufferization.writable", p.value().second);
      }
      f->setAttr("xla_framework.result_mapping", result_mapping);
    }

    if (output_allocation->allocation()->is_tuple()) {
      f->setAttr("xla_framework.result_inner_mapping",
                 mlir::ArrayAttr::get(f.getContext(), result_inner_mapping));
    }
  });
  return {mlir_module};
}
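
// Illustrative result (hypothetical shapes and indices, not emitted verbatim
// by this file): after createMLIRModule, the entry function carries the
// xla_framework attributes consumed by the lowering passes, roughly:
//
//   func.func @main(%arg0: tensor<4xf32>
//       {bufferization.writable = false,
//        xla_framework.input_mapping = 0 : i32})
//       -> tensor<4xf32> attributes {xla_framework.result_mapping = 1 : i32}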

struct ComputationToEmit {
  HloComputation* computation;

  // Are we emitting this computation with fast-math reassociation enabled?
  // We enable reassociation for reductions because it has a significant
  // performance impact.
  bool allow_reassociation;

  bool operator==(const ComputationToEmit& other) const {
    return computation == other.computation &&
           allow_reassociation == other.allow_reassociation;
  }

  template <typename H>
  friend H AbslHashValue(H h, const ComputationToEmit& c) {
    return H::combine(std::move(h), c.computation, c.allow_reassociation);
  }
};

std::vector<ComputationToEmit> SubcomputationEmissionOrder(
    HloComputation* root) {
  absl::flat_hash_set<ComputationToEmit> visited;
  std::vector<ComputationToEmit> postorder;

  // Agenda of (node, leave) pairs; "leave" marks the second (post-order)
  // visit of a node, at which point it is appended to the output.
  std::stack<std::pair<ComputationToEmit, bool>> agenda;
  agenda.emplace(ComputationToEmit{root, false}, false);
  while (!agenda.empty()) {
    ComputationToEmit c;
    bool leave;
    std::tie(c, leave) = agenda.top();
    agenda.pop();

    if (leave) {
      postorder.push_back(c);
      continue;
    }

    if (visited.insert(c).second) {
      agenda.emplace(c, true);
      for (auto* instruction : c.computation->instructions()) {
        bool allow_reassociation =
            instruction->opcode() == HloOpcode::kAllReduce ||
            instruction->opcode() == HloOpcode::kReduce ||
            instruction->opcode() == HloOpcode::kReduceWindow;
        for (auto it = instruction->called_computations().rbegin();
             it != instruction->called_computations().rend(); ++it) {
          HloComputation* called_computation = *it;
          ComputationToEmit callee{
              called_computation, c.allow_reassociation || allow_reassociation};
          if (!visited.contains(callee)) {
            agenda.emplace(callee, false);
          }
        }
      }
    }
  }
  DCHECK(!postorder.empty() && postorder.back().computation == root);
  postorder.pop_back();
  return postorder;
}
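
// Illustrative example (hypothetical module): if the entry computation E
// calls computation A, and A contains a kReduce whose reduction computation
// is B, the DFS above yields the postorder {B, A, E}; E itself is then popped
// off the back, since the entry computation is emitted separately, leaving
// {B, A}. B is marked allow_reassociation because it is reached through a
// reduction instruction.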
1113
1114 } // namespace
1115
1116 StatusOr<std::unique_ptr<CpuExecutable>>
CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module)1117 CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {
1118 ModuleHook pre_optimization_ir_hook;
1119 ModuleHook post_optimization_ir_hook;
1120 std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
1121 GetIRModuleHooks(*module, user_pre_optimization_hook_,
1122 user_post_optimization_hook_);
1123
1124 // Compile must be thread-safe so create a new LLVM context for the module.
1125 mlir::MLIRContext mlir_context;
1126 LoadMLIRDialects(mlir_context);
1127 auto llvm_context = std::make_unique<llvm::LLVMContext>();
1128 auto llvm_module =
1129 std::make_unique<llvm::Module>("__compute_module", *llvm_context);
1130
1131 auto jit = SimpleOrcJIT::Create(
1132 CompilerTargetOptions(module->config()),
1133 CodeGenOptLevel(module->config()),
1134 options::OptimizeForSizeRequested(module->config()),
1135 module->config().debug_options().xla_llvm_disable_expensive_passes(),
1136 llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook,
1137 post_optimization_ir_hook,
1138 OrcJITPostCompilationHook::Create(module.get()));
1139 if (!jit) {
1140 return InternalError("Creating JIT failed: %s",
1141 llvm::toString(jit.takeError()));
1142 }
1143 llvm_module->setDataLayout((*jit)->data_layout());
1144 llvm_module->setTargetTriple((*jit)->target_triple().getTriple());
1145
1146 HloComputation* entry_computation = module->entry_computation();
1147 absl::flat_hash_map<const HloInstruction*, int64_t>
1148 instruction_to_profile_idx;
1149 absl::flat_hash_map<const HloComputation*, int64_t>
1150 computation_to_profile_idx;
1151 std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
1152 std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
1153 if (module->config().hlo_profiling_enabled()) {
1154 TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
1155 *module, &instruction_to_profile_idx, &computation_to_profile_idx,
1156 &hlo_profile_index_map, &hlo_profile_printer_data));
1157 }
1158
1159 // Cache these flags here since we'll want to access them after the module's
1160 // ownership is std::moved.
1161 const bool embed_ir_in_executable =
1162 module->config().debug_options().xla_embed_ir_in_executable();
1163
1164 // Select an order for emitting the HLO instructions for each
1165 // computation. Using this sequence enables tighter buffer liveness analysis
1166 // and reduced memory usage (as compared to using DependencyHloOrdering).
1167 TF_ASSIGN_OR_RETURN(HloSchedule schedule,
1168 ScheduleModule(module.get(), BufferSizeBytesFunction(),
1169 ComputationSchedulerToModuleScheduler(
1170 DFSMemoryScheduler)));
1171
1172 // Run buffer allocation on the HLO graph.
1173 TF_ASSIGN_OR_RETURN(
1174 std::unique_ptr<BufferAssignment> assignment,
1175 BufferAssigner::Run(module.get(),
1176 std::make_unique<SequentialHloOrdering>(schedule),
1177 BufferSizeBytesFunction(), memory_alignment,
1178 /*allocate_buffers_for_constants=*/true));
1179 DumpHloModuleIfEnabled(*module, *assignment, "cpu_after_optimizations");
1180
1181 // Each computation is a single function. Emit all embedded computations
1182 // before the entry computation. The order of computations returned from
1183 // GetEmbeddedComputations guarantees that a called computation occurs
1184 // before a caller computation.
1185
1186 std::string function_name;
1187 if (UseMlirHloLowering(
1188 module->config().debug_options().xla_cpu_enable_mlir_lowering(),
1189 module.get())) {
1190 TF_ASSIGN_OR_RETURN(
1191 auto mlir_module,
1192 createMLIRModule(module.get(), mlir_context, assignment.get()));
1193 TF_RETURN_IF_ERROR(LowerMLIRModule(*mlir_module, mlir_context));
1194
1195 function_name = entry_computation->name();
1196 // TODO(kramerb): Don't rely on the exact function name.
1197 llvm::cast<mlir::LLVM::LLVMFuncOp>(
1198 mlir_module->lookupSymbol("main_xla_framework"))
1199 .setName(function_name);
1200
    llvm_module = mlir::translateModuleToLLVMIR(*mlir_module, *llvm_context);
    if (!llvm_module) {
      return InternalError("Translation to LLVM IR failed");
    }
    llvm_module->setDataLayout((*jit)->data_layout());
    llvm_module->setTargetTriple((*jit)->target_triple().getTriple());
  } else {
    LLVMTargetMachineFeatures target_machine_features((*jit)->target_machine());
    IrEmitter ir_emitter(
        &mlir_context, *module, *assignment, llvm_module.get(),
        std::move(instruction_to_profile_idx),
        std::move(computation_to_profile_idx),
        ModuleComputationsTransitivelyContainCustomCall(*module),
        &target_machine_features,
#ifdef MEMORY_SANITIZER
        /*emit_code_for_msan=*/true
#else
        /*emit_code_for_msan=*/false
#endif
    );

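    // Emit all constants as LLVM globals up front, before any computation
    // that might reference them is emitted.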
    TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());

    for (ComputationToEmit subcomputation :
         SubcomputationEmissionOrder(entry_computation)) {
      if (subcomputation.computation->IsFusionComputation()) {
        continue;
      }
      TF_RETURN_IF_ERROR(
          ir_emitter
              .EmitComputation(
                  subcomputation.computation,
                  subcomputation.computation->name(),
                  /*is_top_level_computation=*/false,
                  schedule.sequence(subcomputation.computation).instructions(),
                  subcomputation.allow_reassociation)
              .status());
    }
    std::string function_name_prefix = entry_computation->name().empty()
                                           ? "__compute"
                                           : entry_computation->name();
    TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
                        ir_emitter.EmitComputation(
                            entry_computation, function_name_prefix,
                            /*is_top_level_computation=*/true,
                            schedule.sequence(entry_computation).instructions(),
                            /*allow_reassociation=*/false));

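    // Apply the data layout's name mangling (e.g. a leading '_' on Mach-O
    // targets) to the entry function name; this mangled name is what gets
    // handed to the CpuExecutable below.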
    function_name = [&]() {
      llvm::SmallVector<char, 40> function_name_vector;
      llvm::Mangler::getNameWithPrefix(function_name_vector,
                                       entry_function->getName(),
                                       (*jit)->data_layout());
      return std::string(function_name_vector.begin(),
                         function_name_vector.end());
    }();
  }

  std::string ir_module_string;
  if (embed_ir_in_executable) {
    ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
  }

  TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));

  // JIT compile the LLVM IR module to in-memory machine code.
  llvm::orc::ThreadSafeModule thread_safe_module(std::move(llvm_module),
                                                 std::move(llvm_context));
  cantFail((*jit)->AddModule(std::move(thread_safe_module)));

  auto cpu_executable = std::make_unique<CpuExecutable>(
      std::move(*jit), std::move(assignment), std::move(module), function_name,
      std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map));

  if (embed_ir_in_executable) {
    cpu_executable->set_ir_module_string(ir_module_string);
  }

  // Dump the computation proto and buffer assignment for debugging and
  // testing if dumping or embed_ir_in_executable is enabled.
  if (embed_ir_in_executable ||
      DumpingEnabledForHloModule(cpu_executable->module())) {
    auto hlo_proto = std::make_unique<HloProto>();
    *hlo_proto->mutable_hlo_module() = cpu_executable->module().ToProto();
    *hlo_proto->mutable_buffer_assignment() =
        cpu_executable->buffer_assignment().ToProto();
    cpu_executable->set_hlo_proto(std::move(hlo_proto));
  }

  return cpu_executable;
}

StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
    std::unique_ptr<HloModule> module,
    [[maybe_unused]] se::StreamExecutor* stream_exec,
    [[maybe_unused]] const CompileOptions& options) {
  VLOG(1) << "Compiling: " << module->name();
  XLA_SCOPED_LOGGING_TIMER(
      absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));
  std::string slow_compilation_msg =
      absl::StrCat("Compiling module ", module->name());
  auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);

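  // LLVM command-line options are process-global, so initialize them exactly
  // once, from the first module config that reaches the compiler.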
  absl::call_once(llvm_command_line_options_initialized,
                  &InitializeLLVMCommandLineOptions, module->config());

  std::unique_ptr<CpuExecutable> cpu_executable;
  TF_ASSIGN_OR_RETURN(cpu_executable,
                      CompileLegacyCpuExecutable(std::move(module)));

  cpu_executable->set_debug_info(
      cpu_executable->buffer_assignment().GetStats().ToString());
  VLOG(1) << "Compilation finished";
  return std::unique_ptr<Executable>(std::move(cpu_executable));
}

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                                const AotCompilationOptions& aot_options) {
  TF_RET_CHECK(!module_group->empty());
  std::vector<std::unique_ptr<HloModule>> modules =
      module_group->ConsumeModules();

  absl::call_once(llvm_command_line_options_initialized,
                  &InitializeLLVMCommandLineOptions, modules[0]->config());

  // We can pass just one llvm::TargetOptions when we compile the LLVM module,
  // so we bail if the configs have conflicting flags. At the moment, the only
  // flags that need to be consistent are for fast-math.
  for (const auto& fn_and_name :
       {std::make_pair(&DebugOptions::xla_cpu_enable_fast_math,
                       "xla_cpu_enable_fast_math"),
        std::make_pair(&DebugOptions::xla_cpu_fast_math_honor_infs,
                       "xla_cpu_fast_math_honor_infs"),
        std::make_pair(&DebugOptions::xla_cpu_fast_math_honor_nans,
                       "xla_cpu_fast_math_honor_nans")}) {
    // This only works because each of the method pointers above returns a
    // bool. Otherwise we'd have to do some template magic.
    const auto& field_method_ptr = fn_and_name.first;
    const auto& field_name = fn_and_name.second;
    bool first_module_val =
        (modules[0]->config().debug_options().*field_method_ptr)();
    for (int64_t i = 0; i < modules.size(); ++i) {
      bool cur_module_val =
          (modules[i]->config().debug_options().*field_method_ptr)();
      if (first_module_val != cur_module_val) {
        return InvalidArgument(
            "All HLO module configs must have the same value for %s, but "
            "module 0 and %d have different values (%d vs %d).",
            field_name, i, first_module_val, cur_module_val);
      }
    }
  }


  if (aot_options.PlatformId() != se::host::kHostPlatformId) {
    return InvalidArgument("Incompatible AOT compilation platform");
  }
  const CpuAotCompilationOptions& options =
      static_cast<const CpuAotCompilationOptions&>(aot_options);
  llvm::Triple triple(llvm::Triple::normalize(options.triple()));
  std::string error;
  const llvm::Target* target =
      llvm::TargetRegistry::lookupTarget(triple.getTriple(), error);
  if (target == nullptr) {
    return InternalError("TargetRegistry::lookupTarget failed: %s", error);
  }

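  // Map the requested relocation model onto LLVM's relocation, PIC, and PIE
  // settings; these are applied to each emitted module below.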
  llvm::Reloc::Model reloc_model = llvm::Reloc::Static;
  llvm::PICLevel::Level pic_level = llvm::PICLevel::NotPIC;
  llvm::PIELevel::Level pie_level = llvm::PIELevel::Default;
  switch (options.relocation_model()) {
    case CpuAotCompilationOptions::RelocationModel::Static:
      reloc_model = llvm::Reloc::Static;
      pic_level = llvm::PICLevel::NotPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::SmallPic:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::SmallPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::BigPic:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::BigPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::SmallPie:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::SmallPIC;
      pie_level = llvm::PIELevel::Small;
      break;
    case CpuAotCompilationOptions::RelocationModel::BigPie:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::BigPIC;
      pie_level = llvm::PIELevel::Large;
      break;
  }
  llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config());
  std::unique_ptr<llvm::TargetMachine> target_machine =
      absl::WrapUnique(target->createTargetMachine(
          triple.getTriple(), options.cpu_name(), options.features(),
          CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None,
          opt_level));

  // Compile must be thread-safe, so create a new LLVM context for the module.
  mlir::MLIRContext mlir_context;
  LoadMLIRDialects(mlir_context);
  llvm::LLVMContext llvm_context;
  std::unique_ptr<llvm::Module> llvm_module;

  std::vector<std::unique_ptr<AotCompilationResult>> results;
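  // Compile each module in the group independently: every module gets its own
  // schedule, buffer assignment, LLVM module, and object file.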
  for (size_t i = 0; i < modules.size(); ++i) {
    HloModule* module = modules[i].get();
    VLOG(1) << "Compiling ahead-of-time: " << module->name();

    TF_RETURN_IF_ERROR(
        RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get(),
                     /*is_mlir_compile=*/options.use_mlir_hlo_lowering()));

    TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                        ScheduleModule(module, BufferSizeBytesFunction()));

    // Run buffer analysis on the HLO graph. This analysis figures out which
    // temporary buffers are required to run the computation.
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<BufferAssignment> assignment,
        BufferAssigner::Run(module,
                            std::make_unique<SequentialHloOrdering>(schedule),
                            BufferSizeBytesFunction(), memory_alignment,
                            /*allocate_buffers_for_constants=*/true));
    // BufferAssignment::ToString() includes a header, so no need for us to
    // print one ourselves.
    if (DumpingEnabledForHloModule(*module)) {
      DumpToFileInDirOrStdout(*module, "", "buffer_assignment",
                              assignment->ToString());
    }
    DumpHloModuleIfEnabled(*module, *assignment, "cpu_after_optimizations");

    absl::flat_hash_map<const HloInstruction*, int64_t>
        instruction_to_profile_idx;
    absl::flat_hash_map<const HloComputation*, int64_t>
        computation_to_profile_idx;
    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;

    if (module->config().hlo_profiling_enabled()) {
      TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
          *module, &instruction_to_profile_idx, &computation_to_profile_idx,
          &hlo_profile_index_map, &hlo_profile_printer_data));
    }

    LLVMTargetMachineFeatures target_machine_features(target_machine.get());
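    // Translate the buffer assignment into the BufferInfo form that is
    // collected here and handed to the CpuAotCompilationResult below.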
    std::vector<BufferInfo> buffer_infos =
        CreateBufferInfosFromBufferAssignment(*assignment);
    HloComputation* computation = module->entry_computation();

    if (UseMlirHloLowering(options.use_mlir_hlo_lowering(), module)) {
      TF_ASSIGN_OR_RETURN(
          auto mlir_module,
          createMLIRModule(module, mlir_context, assignment.get()));
      TF_RETURN_IF_ERROR(LowerMLIRModule(*mlir_module, mlir_context));

      llvm::cast<mlir::LLVM::LLVMFuncOp>(
          mlir_module->lookupSymbol("main_xla_framework"))
          .setName(options.entry_point_name());

      llvm_module = mlir::translateModuleToLLVMIR(*mlir_module, llvm_context);
      if (!llvm_module) {
        return InternalError("Translation to LLVM IR failed");
      }
      // Set the data layout, target triple, and PIC/PIE levels, which the
      // MLIR-to-LLVM translation does not populate.
      llvm_module->setDataLayout(target_machine->createDataLayout());
      llvm_module->setTargetTriple(triple.getTriple());
      if (pic_level != llvm::PICLevel::NotPIC) {
        llvm_module->setPICLevel(pic_level);
      }
      if (pie_level != llvm::PIELevel::Default) {
        llvm_module->setPIELevel(pie_level);
      }
    } else {
      // Set the required module information before emitting IR.
      llvm_module =
          std::make_unique<llvm::Module>("__compute_module", llvm_context);
      llvm_module->setDataLayout(target_machine->createDataLayout());
      llvm_module->setTargetTriple(triple.getTriple());
      if (pic_level != llvm::PICLevel::NotPIC) {
        llvm_module->setPICLevel(pic_level);
      }
      if (pie_level != llvm::PIELevel::Default) {
        llvm_module->setPIELevel(pie_level);
      }
      IrEmitter ir_emitter(
          &mlir_context, *module, *assignment, llvm_module.get(),
          std::move(instruction_to_profile_idx),
          std::move(computation_to_profile_idx),
          ModuleComputationsTransitivelyContainCustomCall(*module),
          &target_machine_features,
          // TODO(b/66051036): Run full msan for AOT.
          /*emit_code_for_msan=*/false);

      TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());

      for (ComputationToEmit subcomputation :
           SubcomputationEmissionOrder(computation)) {
        if (subcomputation.computation->IsFusionComputation()) {
          continue;
        }
        TF_RETURN_IF_ERROR(
            ir_emitter
                .EmitComputation(subcomputation.computation,
                                 subcomputation.computation->name(),
                                 /*is_top_level_computation=*/false,
                                 schedule.sequence(subcomputation.computation)
                                     .instructions(),
                                 subcomputation.allow_reassociation)
                .status());
      }
      const std::string& entry_point_name = options.entry_point_name();
      TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
                          ir_emitter.EmitComputation(
                              computation, entry_point_name,
                              /*is_top_level_computation=*/true,
                              schedule.sequence(computation).instructions(),
                              /*allow_reassociation=*/false));

      CHECK(entry_function->getName() == entry_point_name);
    }


    ModuleHook pre_optimization_ir_hook;
    ModuleHook post_optimization_ir_hook;
    std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
        GetIRModuleHooks(*module, user_pre_optimization_hook_,
                         user_post_optimization_hook_);

    // Run the LLVM verifier over the unoptimized LLVM IR. If it fails, run
    // the pre-optimization IR dump hook before returning.
    {
      Status verify_status = VerifyLlvmModule(*llvm_module);
      if (!verify_status.ok() && pre_optimization_ir_hook) {
        pre_optimization_ir_hook(*llvm_module);
      }
      TF_RETURN_IF_ERROR(verify_status);
    }

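    // After codegen, dump the raw object file alongside the other debug
    // artifacts when dumping is enabled for this module.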
    auto post_codegen_hook = [&](const llvm::object::ObjectFile& obj_file) {
      if (!DumpingEnabledForHloModule(*module)) {
        return;
      }
      DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o",
                      absl::string_view(obj_file.getData().data(),
                                        obj_file.getData().size()));
    };

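    // Run LLVM optimizations and codegen over the module, producing an
    // in-memory object file whose bytes are copied into the AOT result.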
    CompilerFunctor compiler_functor(
        target_machine.get(), opt_level,
        options::OptimizeForSizeRequested(module->config()),
        module->config().debug_options().xla_llvm_disable_expensive_passes(),
        llvm_ir::GetCpuFastMathFlags(module->config()),
        pre_optimization_ir_hook, post_optimization_ir_hook, post_codegen_hook,
        aot_options.sanitize_dataflow(),
        aot_options.sanitize_abilists_dataflow());
    std::unique_ptr<llvm::MemoryBuffer> object_file =
        cantFail(compiler_functor(*llvm_module));
    ObjectFileData object_file_data(object_file->getBufferStart(),
                                    object_file->getBufferEnd());

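    // Record which allocation slice holds the entry computation's result so
    // the AOT client can locate the output buffer at run time.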
    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
                        assignment->GetUniqueTopLevelOutputSlice());

    results.emplace_back(std::make_unique<CpuAotCompilationResult>(
        std::move(object_file_data), std::move(buffer_infos),
        result_slice.index(), std::move(hlo_profile_printer_data)));
  }

  VLOG(1) << "Compilation finished";
  return std::move(results);
}

se::Platform::Id CpuCompiler::PlatformId() const {
  return se::host::kHostPlatformId;
}

HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const {
  return CpuExecutable::ShapeSizeBytes;
}

}  // namespace cpu
}  // namespace xla

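// Register the CPU compiler for the host platform at static-initialization
// time so the XLA compiler registry can construct it on demand.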
static bool InitModule() {
  xla::Compiler::RegisterCompilerFactory(
      stream_executor::host::kHostPlatformId,
      []() { return std::make_unique<xla::cpu::CpuCompiler>(); });
  return true;
}
static bool module_initialized = InitModule();