/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"

#include <stddef.h>
#include <string.h>

#include <functional>
#include <map>
#include <memory>
#include <stack>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc"
// IWYU pragma: no_include "llvm/Config/Targets.def.inc"

#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Mangler.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"  // from @llvm-project
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"  // from @llvm-project
#include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h"  // from @llvm-project
#include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h"  // from @llvm-project
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"  // from @llvm-project
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"  // from @llvm-project
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"  // from @llvm-project
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"  // from @llvm-project
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"  // from @llvm-project
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"  // from @llvm-project
#include "mlir/Conversion/TensorToLinalg/TensorToLinalgPass.h"  // from @llvm-project
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"  // from @llvm-project
#include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/Func/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/IR/Linalg.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/Passes.h"  // from @llvm-project
#include "mlir/Dialect/MemRef/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/SCF/IR/SCF.h"  // from @llvm-project
#include "mlir/Dialect/Shape/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
#include "mlir/Dialect/Vector/IR/VectorOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/ImplicitLocOpBuilder.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"  // from @llvm-project
#include "mlir/Target/LLVMIR/Export.h"  // from @llvm-project
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/ir/xla_framework.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h"
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/gml_st/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Transforms/passes.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/all_gather_decomposer.h"
#include "tensorflow/compiler/xla/service/all_to_all_decomposer.h"
#include "tensorflow/compiler/xla/service/batch_dot_simplification.h"
#include "tensorflow/compiler/xla/service/batchnorm_expander.h"
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/bitcast_dtypes_expander.h"
#include "tensorflow/compiler/xla/service/broadcast_canonicalizer.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/change_op_data_type.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/comparison_expander.h"
#include "tensorflow/compiler/xla/service/conditional_canonicalizer.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/conditional_to_select.h"
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_shape_verifier.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/dynamic_dimension_simplifier.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include "tensorflow/compiler/xla/service/dynamic_padder.h"
#include "tensorflow/compiler/xla/service/eigh_expander.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
#include "tensorflow/compiler/xla/service/hlo_cse.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "tensorflow/compiler/xla/service/llvm_compiler.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_command_line_options.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/logistic_expander.h"
#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/operand_upcaster.h"
#include "tensorflow/compiler/xla/service/optimization_barrier_expander.h"
#include "tensorflow/compiler/xla/service/qr_expander.h"
#include "tensorflow/compiler/xla/service/reduce_decomposer.h"
#include "tensorflow/compiler/xla/service/reduce_scatter_decomposer.h"
#include "tensorflow/compiler/xla/service/reshape_decomposer.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/result_caster.h"
#include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
#include "tensorflow/compiler/xla/service/rng_expander.h"
#include "tensorflow/compiler/xla/service/scatter_expander.h"
#include "tensorflow/compiler/xla/service/select_and_scatter_expander.h"
#include "tensorflow/compiler/xla/service/sharding_propagation.h"
#include "tensorflow/compiler/xla/service/sharding_remover.h"
#include "tensorflow/compiler/xla/service/slice_sinker.h"
#include "tensorflow/compiler/xla/service/slow_operation_alarm.h"
#include "tensorflow/compiler/xla/service/sort_simplifier.h"
#include "tensorflow/compiler/xla/service/spmd/stateful_rng_spmd_partitioner.h"
#include "tensorflow/compiler/xla/service/topk_rewriter.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tree_reduction_rewriter.h"
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace {

// We need to explicitly load all the dialects involved in emitting the IR.
// This is only needed because of how MLIR is bolted into XLA and does not
// make use of the MLIR infrastructure (like using a proper pass pipeline).
// Hopefully this will all go away at some point in favor of a better
// integration.
void LoadMLIRDialects(mlir::MLIRContext& context) {
  context.loadDialect<mlir::arith::ArithmeticDialect,
                      mlir::linalg::LinalgDialect, mlir::scf::SCFDialect,
                      mlir::vector::VectorDialect, mlir::func::FuncDialect,
                      mlir::AffineDialect, mlir::tensor::TensorDialect,
                      mlir::xla_framework::XLAFrameworkDialect>();
  mlir::registerLLVMDialectTranslation(context);
}
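
// Illustrative sketch (hypothetical snippet, not part of this file): building
// an op from a dialect that was never loaded into the MLIRContext fails at
// runtime, which is why LoadMLIRDialects must run before any IR is built.
//
//   mlir::MLIRContext context;      // no dialects loaded yet
//   mlir::OpBuilder builder(&context);
//   // builder.create<mlir::arith::ConstantOp>(...);  // would fail here
//   LoadMLIRDialects(context);      // now arith/linalg/etc. ops can be built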

}  // namespace

namespace xla {

namespace {

bool UseMlirHloLowering(bool use_mlir, HloModule* module) {
  // TODO(tpopp): The prototype currently does not properly handle constant
  // buffers that are handled by the runtime's buffer assignment.
  return use_mlir &&
         module->entry_computation()->root_instruction()->opcode() !=
             HloOpcode::kConstant;
}

// For each computation in the module, determines whether that computation
// calls a custom-call function, either directly or indirectly (e.g. because it
// calls another computation that does).
absl::flat_hash_map<const HloComputation*, bool>
ModuleComputationsTransitivelyContainCustomCall(const HloModule& module) {
  absl::flat_hash_map<const HloComputation*, bool> custom_call_map;
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(&module);

  // Can never fail because we always return an OK status from the visitor.
  TF_CHECK_OK(call_graph->VisitNodes([&custom_call_map](
                                         const CallGraphNode& node) {
    const HloComputation* computation = node.computation();

    for (const HloInstruction* instruction : computation->instructions()) {
      // The computation contains a custom-call instruction directly.
      if (DynCast<HloCustomCallInstruction>(instruction)) {
        custom_call_map[computation] = true;
        return OkStatus();
      }
      // The computation calls something that contains a custom-call
      // instruction (directly or indirectly). This lookup relies on the call
      // graph traversing callees before callers, so that the map is always
      // populated for all callees at this point.
      for (const HloComputation* callee : instruction->called_computations()) {
        bool callee_contains_custom_call = FindOrDie(custom_call_map, callee);
        if (callee_contains_custom_call) {
          custom_call_map[computation] = true;
          return OkStatus();
        }
      }
    }

    custom_call_map[computation] = false;
    return OkStatus();
  }));

  return custom_call_map;
}
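
// Example (hypothetical module): for a call chain main -> while_body ->
// helper, where only `helper` contains a kCustomCall instruction, the
// returned map is {helper: true, while_body: true, main: true}; the
// post-order visit reaches `helper` first, so the FindOrDie lookups for its
// callers always find a populated entry.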

}  // namespace

namespace cpu {
using BufferInfo = cpu_function_runtime::BufferInfo;

CpuAotCompilationOptions::CpuAotCompilationOptions(
    std::string triple, std::string cpu_name, std::string features,
    std::string entry_point_name, RelocationModel relocation_model)
    : triple_(std::move(triple)),
      cpu_name_(std::move(cpu_name)),
      features_(std::move(features)),
      entry_point_name_(std::move(entry_point_name)),
      relocation_model_(relocation_model) {}

CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;

se::Platform::Id CpuAotCompilationOptions::PlatformId() const {
  return se::host::kHostPlatformId;
}

CpuAotCompilationResult::CpuAotCompilationResult(
    ObjectFileData object_file_data, std::vector<BufferInfo> buffer_infos,
    int64_t result_buffer_index,
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data)
    : object_file_data_(std::move(object_file_data)),
      buffer_infos_(std::move(buffer_infos)),
      result_buffer_index_(result_buffer_index),
      hlo_profile_printer_data_(std::move(hlo_profile_printer_data)) {}

CpuAotCompilationResult::~CpuAotCompilationResult() = default;

CpuCompiler::CpuCompiler() {
  // Initialize LLVM the first time the CpuCompiler is initialized.
  static bool llvm_initialized = []() {
    InitializeLLVMTarget();
    return true;
  }();
  (void)llvm_initialized;
}

StatusOr<std::vector<std::unique_ptr<Executable>>> CpuCompiler::Compile(
    std::unique_ptr<HloModuleGroup> module_group,
    std::vector<std::vector<se::StreamExecutor*>> stream_execs,
    const CompileOptions& options) {
  for (const std::vector<se::StreamExecutor*>& se_vector : stream_execs) {
    if (se_vector.size() != 1) {
      return Unimplemented(
          "Model partitioning not implemented for the CPU compiler");
    }
  }
  return LLVMCompiler::Compile(std::move(module_group), stream_execs, options);
}

/* static */ void CpuCompiler::InitializeLLVMTarget() {
  // Initialize LLVM's MC layer for the native target.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();
}

namespace {

// LLVM makes certain options configurable only through its command-line
// options; it provides the ParseCommandLineOptions function that lets us set
// flags at runtime. However, since these flags are global we want to avoid
// multiple invocations of the LLVM compilation pipeline with a different set
// of flags. Therefore, we only pass command-line flags to LLVM once, before
// the first module is compiled.
absl::once_flag llvm_command_line_options_initialized;
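
// Assumed usage sketch (the real call sites live in the compile entry points
// further below, outside this excerpt): each compilation path gates the
// global flag parsing on this once-flag, e.g.
//
//   absl::call_once(llvm_command_line_options_initialized,
//                   &InitializeLLVMCommandLineOptions, module->config());
//
// so LLVM's global option parser runs at most once per process.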

// This visitor records which HLO instructions should have profiling
// information recorded.
class CollectProfileCandidates : public DfsHloVisitorWithDefault {
 public:
  static StatusOr<absl::flat_hash_map<const HloInstruction*, int64_t>>
  GetCandidatesForComputation(
      const HloComputation& computation,
      const absl::flat_hash_map<const HloInstruction*, int64_t>&
          assigned_indices) {
    absl::flat_hash_map<const HloInstruction*, int64_t> hlo_to_profile_idx;
    CollectProfileCandidates profile_candidates_for_computation(
        &hlo_to_profile_idx, assigned_indices);
    TF_RETURN_IF_ERROR(computation.Accept(&profile_candidates_for_computation));
    return hlo_to_profile_idx;
  }

 private:
  CollectProfileCandidates(
      absl::flat_hash_map<const HloInstruction*, int64_t>* hlo_to_profile_idx,
      const absl::flat_hash_map<const HloInstruction*, int64_t>&
          assigned_indices)
      : hlo_to_profile_idx_(hlo_to_profile_idx),
        assigned_indices_(assigned_indices) {}

  Status DefaultAction(HloInstruction* hlo_instruction) override {
    hlo_to_profile_idx_->insert(
        {hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)});
    return OkStatus();
  }

  Status HandleCall(HloInstruction* call) override {
    TF_RETURN_IF_ERROR(DefaultAction(call));
    CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_,
                                                 assigned_indices_);
    TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
    return OkStatus();
  }
  // Recurse into "conditional" so we can profile inside of it.
  Status HandleConditional(HloInstruction* conditional) override {
    TF_RETURN_IF_ERROR(DefaultAction(conditional));

    CollectProfileCandidates candidates_for_true(hlo_to_profile_idx_,
                                                 assigned_indices_);
    TF_RETURN_IF_ERROR(
        conditional->true_computation()->Accept(&candidates_for_true));

    CollectProfileCandidates candidates_for_false(hlo_to_profile_idx_,
                                                  assigned_indices_);
    TF_RETURN_IF_ERROR(
        conditional->false_computation()->Accept(&candidates_for_false));

    return OkStatus();
  }

  // Skip constants, there is nothing to profile.
  Status HandleConstant(HloInstruction*) override { return OkStatus(); }
  // Skip parameters, they are a simple load.
  Status HandleParameter(HloInstruction*) override { return OkStatus(); }
  // It is important to recurse for "while" or else we risk overly coarse
  // profiling information.
  Status HandleWhile(HloInstruction* xla_while) override {
    TF_RETURN_IF_ERROR(DefaultAction(xla_while));

    CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_,
                                                      assigned_indices_);
    TF_RETURN_IF_ERROR(
        xla_while->while_condition()->Accept(&candidates_for_condition));

    CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_,
                                                 assigned_indices_);
    TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body));

    return OkStatus();
  }

  absl::flat_hash_map<const HloInstruction*, int64_t>* hlo_to_profile_idx_;
  const absl::flat_hash_map<const HloInstruction*, int64_t>& assigned_indices_;
};

// Adds the HloVerifier for CPU to the given pipeline.
void AddHloVerifier(HloPassPipeline* pipeline, HloVerifierOpts&& opts = {},
                    bool debug_only = false) {
  std::unique_ptr<TargetVerifierMetadata> verifier_metadata =
      std::make_unique<CpuVerifierMetadata>(std::move(opts));
  if (debug_only) {
    pipeline->AddInvariantCheckerDebug<HloVerifier>(
        std::move(verifier_metadata), "hlo verifier (debug)");
  } else {
    pipeline->AddInvariantChecker<HloVerifier>(std::move(verifier_metadata),
                                               "hlo verifier");
  }
}

}  // namespace

Status CpuCompiler::RunHloPassesThroughLayoutAssn(
    HloModule* module, bool /*is_aot_compile*/,
    LLVMTargetMachineFeatures* target_machine_features, bool is_mlir_compile) {
  if (module->config().use_spmd_partitioning()) {
    HloPassPipeline spmd_pipeline("spmd-partitioner");
    const int64_t num_partitions = module->config().num_partitions();
    if (num_partitions > 1) {
      // Run some IR cleanup passes before running the SPMD partitioning
      // passes.
      AddHloVerifier(&spmd_pipeline);
      spmd_pipeline.AddPass<CallInliner>();
      spmd_pipeline.AddPass<ZeroSizedHloElimination>();
      spmd_pipeline.AddPass<ConditionalCanonicalizer>();

      spmd_pipeline.AddPass<ShardingPropagation>(
          /*is_spmd=*/true, /*propagate_metadata=*/false,
          module->config().allow_spmd_sharding_propagation_to_output());
      spmd_pipeline.AddPass<spmd::StatefulRngSpmdPartitioner>(
          num_partitions, module->config().replica_count());
    } else {
      // Remove redundant sharding ops when partition_count == 1.
      spmd_pipeline.AddPass<ShardingRemover>();
      spmd_pipeline.AddPass<HloDCE>();
    }
    TF_RETURN_IF_ERROR(spmd_pipeline.Run(module).status());
  }

  HloPassPipeline pipeline("HLO passes through layout assignment");
  AddHloVerifier(&pipeline);

  pipeline.AddPass<OperandUpcaster>();
  pipeline.AddPass<ResultCaster>();

  // Expand random number generation.
  pipeline.AddPass<RngExpander>();
  pipeline.AddPass<RngBitGeneratorExpander>(RandomAlgorithm::RNG_PHILOX);

  // Remove zero-sized HLO from the input so that other passes don't have to
  // handle it.
  pipeline.AddPass<ZeroSizedHloElimination>();

  pipeline.AddPass<DynamicIndexSplitter>();

  pipeline.AddPass<ConditionalToSelect>();
  pipeline.AddPass<MapInliner>();

  pipeline.AddPass<ComparisonExpander>();
  pipeline.AddPass<CholeskyExpander>();
  pipeline.AddPass<QrExpander>();
  pipeline.AddPass<EighExpander>();
  pipeline.AddPass<TriangularSolveExpander>();
  pipeline.AddPass<AllGatherDecomposer>();
  pipeline.AddPass<AllToAllDecomposer>();
  pipeline.AddPass<ReduceScatterDecomposer>();

  // Inline computations with a single call site.
  pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
  pipeline.AddPass<BatchDotSimplification>();
  pipeline.AddPass<DotDecomposer>();
  // Convert BF16 operations to F32 operations so that the CPU backend can
  // support BF16 operations without directly implementing a BF16 lowering for
  // most ops.
  BFloat16Support bf16;
  pipeline.AddPass<BFloat16Normalization>(&bf16);
  // After canonicalization, there may be more batch dots that can be
  // simplified.
  pipeline.AddPass<BatchDotSimplification>();
  auto cost_model = [](HloInstruction* conv) {
    // We need a cost model for CPUs. Currently, do nothing.
    return false;
  };
  pipeline.AddPass<ConvolutionGroupConverter>(
      /*should_expand=*/[](HloInstruction* conv) { return true; }, cost_model,
      /*convert_batch_groups_only=*/true);
  auto feature_group_should_expand = [](HloInstruction* conv) {
    switch (conv->shape().element_type()) {
      case F16:
      case F32:
        return false;
      default:
        return true;
    }
  };
  pipeline.AddPass<ConvolutionGroupConverter>(
      feature_group_should_expand, cost_model,
      /*convert_batch_groups_only=*/false);
  pipeline.AddPass<BatchNormExpander>(
      /*rewrite_training_op=*/true,
      /*rewrite_inference_op=*/true,
      /*rewrite_grad_op=*/true);
  pipeline.AddPass<LogisticExpander>(
      /*expansion_type=*/LogisticExpansionType::kExp);
  pipeline.AddPass<ConditionalCanonicalizer>();
  pipeline.AddPass<DynamicDimensionSimplifier>();
  auto dynamic_padder_options = DynamicPadderOptions();
  dynamic_padder_options.shape_check_mode =
      DynamicDimensionInference::ShapeCheckMode::kCompileTime;
  pipeline.AddPass<DynamicPadder>(dynamic_padder_options);
  pipeline.AddPass<SelectAndScatterExpander>();
  pipeline.AddPass<ScatterExpander>(ScatterExpander::kEliminateAllScatters);
  pipeline.AddPass<ConvCanonicalization>(target_machine_features);

  // Run fp16 dots/convs in fp32 and then downcast the result to fp16.
  // Justification:
  //
  //   - This is significantly faster on our CPUs today than true fp16.
  //   - It's numerically more accurate.  (Granted, this is not always
  //     desirable, thus the ability to disable this functionality.)
  //   - It matches more closely the GPU's behavior on fp16 dot/conv, where
  //     accumulation happens in f32.
  if (!module->config().debug_options().xla_cpu_strict_dot_conv_math()) {
    pipeline.AddPass<ChangeOpDataType>(
        F16, F32, [](const HloInstruction* instr) {
          return instr->opcode() == HloOpcode::kDot ||
                 instr->opcode() == HloOpcode::kConvolution;
        });
  }
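
  // Sketch of the rewrite the ChangeOpDataType pass above performs (HLO
  // pseudocode, hypothetical shapes): an f16 dot
  //   f16[m,n] dot(f16[m,k] %a, f16[k,n] %b)
  // becomes
  //   f16[m,n] convert(f32[m,n] dot(convert(%a), convert(%b)))
  // i.e. operands are upcast to f32, the dot runs in f32, and the result is
  // downcast back to f16.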

  // Run the following passes to a fixed point.
  [&pipeline =
       pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification")] {
    AddHloVerifier(&pipeline, HloVerifierOpts{}, /*debug_only=*/true);

    AlgebraicSimplifierOptions options;
    options.set_enable_dot_strength_reduction(false);
    // TODO(b/209827141): XLA:CPU doesn't propagate NaN through min/max, but
    // other platforms do, so it should be changed.
    options.set_minmax_propagate_nan(false);
    pipeline.AddPass<AlgebraicSimplifier>(options);
    pipeline.AddPass<SortSimplifier>();
    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<GatherExpander>(GatherExpander::kEliminateSimpleGathers);

    // Needs to happen after algebraic simplifier.
    pipeline.AddPass<TreeReductionRewriter>();

    // BatchNormExpander can create zero-sized ops, so zero-sized HLO
    // elimination has to come after that pass.
    pipeline.AddPass<ZeroSizedHloElimination>();

    pipeline.AddPass<WhileLoopInvariantCodeMotion>();
    pipeline.AddPass<TupleSimplifier>();
    pipeline.AddPass<WhileLoopConstantSinking>();
    pipeline.AddPass<WhileLoopSimplifier>();

    // TODO(b/134075051): Re-enable after b/134075051 is fixed.
    // pipeline.AddPass<SliceSinker>();

    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<ReshapeMover>();
    pipeline.AddPass<HloConstantFolding>();
    pipeline.AddPass<ConditionalSimplifier>();
  }();
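
  // Note: HloPassFix<HloPassPipeline> reruns the nested "simplification"
  // pipeline until no pass reports a change (e.g. simplification exposing
  // dead code, DCE enabling further simplification). The immediately invoked
  // lambda above only registers the passes; the fixed-point iteration happens
  // when the outer pipeline runs.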
  pipeline.AddPass<BitcastDtypesExpander>();

  // XLA lowers topk to a libcall while the MLIR based pipeline does not yet
  // support libcalls. Disable this for now.
  if (!is_mlir_compile) {
    pipeline.AddPass<TopkRewriter>([](const HloSortInstruction* sort, int64_t) {
      return sort->operand(0)->shape().element_type() == F32;
    });
  }
  pipeline.AddPass<IndexedArrayAnalysisPrinterPass>();
  pipeline.AddPass<TransposeFolding>(
      [&](const HloInstruction& dot, int64_t operand) -> StatusOr<bool> {
        if (DotImplementationCanHandleTranspose(dot,
                                                *target_machine_features)) {
          return TransposeFolding::IsRowColumnTransposeDotOperand(dot, operand);
        }
        return false;
      },
      TransposeFolding::NeverFoldTranspose);
  pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);

  pipeline.AddPass<OptimizationBarrierExpander>();
  pipeline.AddPass<TupleSimplifier>();

  // Layout assignment uses alias analysis, which requires the call graph to be
  // flattened.
  pipeline.AddPass<FlattenCallGraph>();
  ChannelLayoutConstraints layout_constraints;
  pipeline.AddPass<CpuLayoutAssignment>(
      module->mutable_entry_computation_layout(), target_machine_features,
      &layout_constraints);

  return pipeline.Run(module).status();
}

Status CpuCompiler::RunHloPassesAfterLayoutAssn(
    HloModule* module, bool is_aot_compile,
    LLVMTargetMachineFeatures* target_machine_features, bool is_mlir_compile) {
  {
    HloPassPipeline pipeline("hlo normalization");
    pipeline.AddPass<ReshapeDecomposer>();
    pipeline.AddPass<ReduceDecomposer>();
    pipeline.AddPass<BroadcastCanonicalizer>();
    TF_RETURN_IF_ERROR(pipeline.Run(module).status());
  }

  HloPassPipeline pipeline("HLO passes after layout assignment");

  // CopyInsertion is still needed by BufferAssignment. MLIR passes will handle
  // everything else done by XLA, but CopyInsertion is needed to interface with
  // the existing runtime.
  if (is_mlir_compile) {
    pipeline.AddPass<CopyInsertion>();
    return pipeline.Run(module).status();
  }

  // After layout assignment, use a layout-sensitive verifier.
  pipeline.AddPass<HloPassPipeline>("after layout assignment");
  AddHloVerifier(&pipeline, HloVerifierOpts{}.MakeLayoutSensitive(),
                 /*debug_only=*/true);

  pipeline.AddPass<ReshapeDecomposer>();

  // Add a fusion pass now that layout assignment is done.
  pipeline.AddPass<CpuInstructionFusion>();

  // The LayoutAssignment pass may leave behind kCopy instructions which are
  // duplicate or NOPs, so remove them with algebraic simplification and CSE.
  // Run this to a fixed point.
  [&pipeline = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
       "simplification after layout assignment")] {
    AddHloVerifier(
        &pipeline,
        HloVerifierOpts{}.MakeLayoutSensitive().WithInstructionCanChangeLayout(
            LayoutAssignment::InstructionCanChangeLayout),
        /*debug_only=*/true);
    AlgebraicSimplifierOptions options;
    options.set_is_layout_sensitive(true);
    options.set_enable_dot_strength_reduction(false);
    // TODO(b/209827141): XLA:CPU doesn't propagate NaN through min/max, but
    // other platforms do, so it should be changed.
    options.set_minmax_propagate_nan(false);
    pipeline.AddPass<AlgebraicSimplifier>(options);
    pipeline.AddPass<HloDCE>();
    pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
  }();

  // Outline ops in the entry computation into calls to subcomputations.
  const int max_parallelism =
      module->config().intra_op_parallelism_threads() > 0
          ? module->config().intra_op_parallelism_threads()
          : tensorflow::port::NumSchedulableCPUs();
  if (!is_aot_compile) {
    // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module.
    // Note this is not run for AOT because it would bring in thread pool
    // and thread synchronization dependencies which would likely increase
    // binary size (and most AOT applications are single-threaded).
    // TODO(b/29630486) Support multi-threaded AOT.
    pipeline.AddPass<ParallelTaskAssigner>(
        max_parallelism, ShapeSizeBytesFunction(), target_machine_features);
  }
  // Copy insertion should be performed immediately before IR emission to
  // avoid inserting unnecessary copies (later pass adds an instruction which
  // materializes the value) or missing a necessary copy (later pass removes
  // an instruction which materializes a value). DCE must be run immediately
  // before (and sometimes after) copy insertion, to avoid dead code from
  // interfering with the rewrites.
  pipeline.AddPass<HloDCE>();
  pipeline.AddPass<CopyInsertion>();
  pipeline.AddPass<HloDCE>();
  return pipeline.Run(module).status();
}

Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile,
                                 llvm::TargetMachine* target_machine,
                                 bool is_mlir_compile) {
  LLVMTargetMachineFeatures target_machine_features(target_machine);
  TF_RETURN_IF_ERROR(RunHloPassesThroughLayoutAssn(
      module, is_aot_compile, &target_machine_features, is_mlir_compile));

  return RunHloPassesAfterLayoutAssn(
      module, is_aot_compile, &target_machine_features,
      UseMlirHloLowering(is_mlir_compile, module));
}

namespace {

// Align buffers to 16-byte boundaries.
int64_t memory_alignment(LogicalBuffer::Color) {
  return cpu_function_runtime::MinAlign();
}

llvm::TargetOptions CompilerTargetOptions(
    const HloModuleConfig& module_config) {
  llvm::TargetOptions target_options;
  // Always allow FMA fusion. This increases precision instead of decreasing it.
  target_options.AllowFPOpFusion = llvm::FPOpFusion::Fast;
  return target_options;
}
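
// With FPOpFusion::Fast the backend may contract a separate multiply and add,
// e.g. (LLVM IR sketch)
//   %m = fmul float %a, %b
//   %r = fadd float %m, %c
// into a single fused multiply-add that rounds once instead of twice.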

llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) {
  VLOG(2) << "backend_optimization_level: "
          << module_config.debug_options().xla_backend_optimization_level();
  switch (module_config.debug_options().xla_backend_optimization_level()) {
    case 1:
      return llvm::CodeGenOpt::Less;
    case 2:
      return llvm::CodeGenOpt::Default;
    case 3:
      return llvm::CodeGenOpt::Aggressive;
    default:
      return llvm::CodeGenOpt::None;
  }
}
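
// E.g. compiling with --xla_backend_optimization_level=2 selects
// llvm::CodeGenOpt::Default; values outside 1..3 fall back to
// CodeGenOpt::None.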

std::pair<LLVMCompiler::ModuleHook, LLVMCompiler::ModuleHook> GetIRModuleHooks(
    const HloModule& hlo_module,
    const LLVMCompiler::ModuleHook& user_pre_optimization_hook,
    const LLVMCompiler::ModuleHook& user_post_optimization_hook) {
  // Create the IR hooks. If applicable, each IR hook does the following:
  //
  //  * Calls the user supplied module hook.
  //  * Writes out the IR to a file in the output directory designated by
  //    --xla_dump_to
  const HloModule* hlo_module_ptr = &hlo_module;
  auto hook = [user_pre_optimization_hook, user_post_optimization_hook,
               hlo_module_ptr](bool optimized,
                               const llvm::Module& llvm_module) {
    const auto& user_hook =
        !optimized ? user_pre_optimization_hook : user_post_optimization_hook;
    if (user_hook) {
      user_hook(llvm_module);
    }
    llvm_ir::DumpIrIfEnabled(*hlo_module_ptr, llvm_module, optimized);
  };
  return {[hook](const llvm::Module& llvm_module) {
            return hook(/*optimized=*/false, llvm_module);
          },
          [hook](const llvm::Module& llvm_module) {
            return hook(/*optimized=*/true, llvm_module);
          }};
}

Status VerifyLlvmModule(const llvm::Module& llvm_module) {
  XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier");

  std::string err;
  llvm::raw_string_ostream err_stream(err);

  // verifyModule() returns true if the module is broken.
  TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream))
      << "Invalid LLVM IR before optimizations:\n"
      << err_stream.str()
      << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
         "Rerun with --xla_dump_to to get the IR. ";
  return OkStatus();
}

Status CreateHloProfilingArtifacts(
    const HloModule& module,
    absl::flat_hash_map<const HloInstruction*, int64_t>*
        instruction_to_profile_idx,
    absl::flat_hash_map<const HloComputation*, int64_t>*
        computation_to_profile_idx,
    std::unique_ptr<HloProfileIndexMap>* hlo_profile_index_map,
    std::unique_ptr<HloProfilePrinterData>* hlo_profile_printer_data) {
  *hlo_profile_index_map = std::make_unique<HloProfileIndexMap>(module);
  const HloComputation& entry_computation = *module.entry_computation();

  TF_ASSIGN_OR_RETURN(
      *instruction_to_profile_idx,
      CollectProfileCandidates::GetCandidatesForComputation(
          entry_computation,
          (*hlo_profile_index_map)->instruction_to_profile_idx()));

  auto shape_size_bytes = [](const Shape& shape) {
    // On the cpu, opaques are pointers.
    if (shape.IsOpaque()) {
      return static_cast<int64_t>(sizeof(void*));
    }
    return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
  };

  HloCostAnalysis cost_analysis(shape_size_bytes);
  TF_RETURN_IF_ERROR(entry_computation.Accept(&cost_analysis));
  *hlo_profile_printer_data = CreateHloProfilePrinterData(
      **hlo_profile_index_map, cost_analysis, entry_computation.name());
  *computation_to_profile_idx =
      (*hlo_profile_index_map)->computation_to_profile_idx();

  return OkStatus();
}

}  // namespace

StatusOr<std::unique_ptr<HloModule>> CpuCompiler::RunHloPasses(
    std::unique_ptr<HloModule> module, se::StreamExecutor* /*stream_exec*/,
    const CompileOptions& /*options*/) {
  std::unique_ptr<llvm::TargetMachine> jit_target_machine =
      SimpleOrcJIT::InferTargetMachineForJIT(
          CompilerTargetOptions(module->config()),
          CodeGenOptLevel(module->config()));

  TF_RETURN_IF_ERROR(RunHloPasses(
      module.get(), /*is_aot_compile=*/false, jit_target_machine.get(),
      /*is_mlir_compile=*/
      module->config().debug_options().xla_cpu_enable_mlir_lowering()));
  return std::move(module);
}

StatusOr<std::unique_ptr<BufferAssignment>> CpuCompiler::AssignBuffers(
    const HloModule* module) {
  // Select an order for emitting the HLO instructions for each computation.
  // Using this sequence enables tighter buffer liveness analysis and reduced
  // memory usage (as compared to using DependencyHloOrdering).
  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      ScheduleModule(module, BufferSizeBytesFunction(),
                                     ComputationSchedulerToModuleScheduler(
                                         DFSMemoryScheduler)));

  // Run buffer allocation on the HLO graph.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<BufferAssignment> assignment,
      BufferAssigner::Run(module,
                          std::make_unique<SequentialHloOrdering>(schedule),
                          BufferSizeBytesFunction(), memory_alignment,
                          /*allocate_buffers_for_constants=*/true));

  return std::move(assignment);
}

namespace {

// Post-compilation callback functor for use by SimpleOrcJIT.
//
// Dumps machine code if dumping is enabled for the module.
struct OrcJITPostCompilationHook {
  // Gets an std::function that implements this hook.
  static std::function<void(const llvm::object::ObjectFile& obj_file)> Create(
      const HloModule* module) {
    // This struct is not copyable, but std::functions must be.  So to create an
    // std::function out of this struct, we have to wrap it in a shared_ptr.
    auto wrapped = std::make_shared<OrcJITPostCompilationHook>(module);
    return [wrapped](const llvm::object::ObjectFile& obj_file) {
      (*wrapped)(obj_file);
    };
  }

  // Constructor can't be private because we want to call it from
  // std::make_shared, but users should call Create() instead.
  explicit OrcJITPostCompilationHook(const HloModule* module)
      : module(module) {}

 private:
  void operator()(const llvm::object::ObjectFile& obj_file) {
    if (!DumpingEnabledForHloModule(*module)) {
      return;
    }
    DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o",
                    absl::string_view(obj_file.getData().data(),
                                      obj_file.getData().size()));
  }

  const HloModule* module;
};

void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
  llvm_ir::InitializeLLVMCommandLineOptions(
      config.debug_options().xla_backend_extra_options());
}

Status LowerMLIRModule(mlir::ModuleOp mlir_module,
                       mlir::MLIRContext& mlir_context) {
  LoadMLIRDialects(mlir_context);
  mlir::PassManager pm(&mlir_context);
  // Resolve all shape constraints (e.g. broadcast constraints that can be
  // proved statically and changed to const witness) early to allow more
  // efficient broadcast operations moving.
  // Move up broadcasting operations to allow for more fusion opportunities.
  pm.addPass(mlir::createInlinerPass());
  pm.addPass(mlir::mhlo::createExpandHloTuplesPass("main"));
  // TODO(b/233771980): Remove once custom_call doesn't use tuples.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::mhlo::createFlattenTuplePass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mhlo::createLegalizeGeneralDotPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mhlo::createBroadcastPropagationPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Transform HLO operations to Linalg.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::mhlo::createLegalizeSortPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mhlo::createLegalizeControlFlowPass());
  pm.addPass(::mlir::mhlo::createLegalizeToArithmeticPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::mhlo::createLegalizeHloToLinalgPass());

  // Lower index cast on tensors to tensor.generate.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createLowerIndexCastPass());

  pm.addPass(mlir::mhlo::createConvertToSignlessPass());

  // Lower shape dialect to standard to enable linalg canonicalizations (e.g.
  // use linalg inputs instead of outputs for memref.dim operations).
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createShapeSimplification());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createShapeToShapeLowering());
  pm.addPass(mlir::createConvertShapeToStandardPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createConvertShapeConstraintsPass());

  // Fuse Linalg on tensors operations.
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::memref::createResolveShapedTypeResultDimsPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createLinalgElementwiseOpFusionPass());
  pm.addPass(mlir::createReconcileUnrealizedCastsPass());
  pm.addPass(mlir::createConvertTensorToLinalgPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createLinalgInitTensorToAllocTensorPass());

  // Always run canonicalizer (which does dead code removal) before
  // bufferizing anything.
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::hlo::createOneShotBufferizePass());

  // Handle framework specific requirements for buffers and then insert
  // deallocations for temporary buffers.
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createConvertLinalgToLoopsPass());
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::bufferization::createBufferResultsToOutParamsPass());
  pm.addPass(mlir::mhlo::CreateOutlineWithXLAFrameworkPass());
  pm.addPass(mlir::createInlinerPass());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::bufferization::createBufferDeallocationPass());

  pm.addPass(mlir::createBufferizationToMemRefPass());

  // Specialize linalg.matmul to linalg.dot, linalg.matvec or linalg.vecmat,
  // and immediately canonicalize to clean up not taken branches.
  // pm.addNestedPass<mlir::func::FuncOp>(CreateLinalgMatmulSpecializationPass());
  pm.addPass(mlir::createCanonicalizerPass());

  // Tile and vectorize linalg operation using Linalg Codegen Strategy.
  // pm.addNestedPass<mlir::func::FuncOp>(CreateCodegenStrategyForMatMulPass());

  // TODO(tpopp): Move this to mlir::hlo::createGenericHostToLLVMPass?
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createConvertComplexToStandardPass());

  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createCanonicalizerPass());

  mlir::VectorTransferToSCFOptions vec_to_scf_options;
  vec_to_scf_options.unroll = true;
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::createConvertVectorToSCFPass(vec_to_scf_options));
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::arith::createArithmeticExpandOpsPass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::memref::createExpandOpsPass());
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createLowerAffinePass());
  pm.addPass(mlir::mhlo::CreateLegalizeXLAFrameworkToLLVMPass());
  pm.addPass(mlir::hlo::createGenericHostToLLVMPass());
  pm.addPass(mlir::createReconcileUnrealizedCastsPass());
  if (pm.run(mlir_module).failed()) {
    mlir_module->dump();
    return tensorflow::errors::Internal(
        "Failed to compile through MLIR pipeline");
  }

  return OkStatus();
}
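
// Assumed overall flow (the call sites appear in CompileLegacyCpuExecutable
// below): build the module with createMLIRModule(), run it through the
// lowering pipeline above, then translate the resulting LLVM-dialect module
// to an llvm::Module (e.g. via mlir::translateModuleToLLVMIR).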

StatusOr<mlir::OwningOpRef<mlir::ModuleOp>> createMLIRModule(
    HloModule* module, mlir::MLIRContext& mlir_context,
    BufferAssignment* assignment) {
  LoadMLIRDialects(mlir_context);
  mlir::OpBuilder builder(&mlir_context);
  auto mlir_module = builder.create<mlir::ModuleOp>(builder.getUnknownLoc());
  TF_RETURN_IF_ERROR(ConvertHloToMlirHlo(mlir_module, module));

  // Add buffer mappings. The first attribute is the index of the slice, the
  // second is a boolean attribute on whether the allocation is writeable.
  llvm::SmallVector<std::pair<mlir::Attribute, mlir::Attribute>>
      operand_mapping;
  for (auto i : module->entry_computation()->parameter_instructions()) {
    auto slice = assignment->GetUniqueTopLevelSlice(i);
    operand_mapping.emplace_back(
        builder.getI32IntegerAttr(static_cast<int32_t>(slice->index())),
        builder.getBoolAttr(!slice->allocation()->is_readonly()));
  }

  auto root_instr = module->entry_computation()->root_instruction();
  auto output_allocation = assignment->GetUniqueTopLevelOutputSlice();

  // Gather mappings to each element in the tuple if necessary.
  llvm::SmallVector<mlir::Attribute> result_inner_mapping;
  if (output_allocation->allocation()->is_tuple()) {
    for (auto i : llvm::seq<int>(0, root_instr->shape().tuple_shapes_size())) {
      result_inner_mapping.push_back(mlir::IntegerAttr::get(
          mlir::IntegerType::get(&mlir_context, 64),
          assignment->GetUniqueSlice(root_instr, {i})->index()));
    }
  }

  auto result_mapping = builder.getI32IntegerAttr(
      static_cast<int32_t>(output_allocation->index()));
  mlir_module->walk([&](mlir::func::FuncOp f) {
    if (f.getSymName() == "main") {
      for (auto& p : llvm::enumerate(operand_mapping)) {
        f.setArgAttr(p.index(), "xla_framework.input_mapping", p.value().first);
        // Mark argument as (non-)writeable for bufferization. This ensures
        // that entry parameters are not overwritten.
        f.setArgAttr(p.index(), "bufferization.writable", p.value().second);
      }
      f->setAttr("xla_framework.result_mapping", result_mapping);
    }

    if (output_allocation->allocation()->is_tuple()) {
      f->setAttr("xla_framework.result_inner_mapping",
                 mlir::ArrayAttr::get(f.getContext(), result_inner_mapping));
    }
  });
  return {mlir_module};
}
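
// Resulting IR sketch (hypothetical shapes and slice indices): after the walk
// above, the entry function carries its buffer-assignment slots as
// attributes, roughly
//
//   func.func @main(%arg0: tensor<4xf32>
//       {xla_framework.input_mapping = 0 : i32,
//        bufferization.writable = false})
//       attributes {xla_framework.result_mapping = 1 : i32} { ... }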

struct ComputationToEmit {
  HloComputation* computation;

  // Are we emitting this computation with fast-math reassociation enabled?
  // We enable reassociation for reductions because it has a significant
  // performance impact.
  bool allow_reassociation;

  bool operator==(const ComputationToEmit& other) const {
    return computation == other.computation &&
           allow_reassociation == other.allow_reassociation;
  }

  template <typename H>
  friend H AbslHashValue(H h, const ComputationToEmit& c) {
    return H::combine(std::move(h), c.computation, c.allow_reassociation);
  }
};
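
// The operator== / AbslHashValue pair above is what lets ComputationToEmit
// serve as a key in Abseil hash containers, e.g. the
// absl::flat_hash_set<ComputationToEmit> `visited` set in the function below.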

std::vector<ComputationToEmit> SubcomputationEmissionOrder(
    HloComputation* root) {
  absl::flat_hash_set<ComputationToEmit> visited;
  std::vector<ComputationToEmit> postorder;

  // Agenda of (node, leave) pairs: each computation is pushed twice, first to
  // expand its callees and later, with leave == true, to append it to the
  // postorder once all of its callees have been handled.
  std::stack<std::pair<ComputationToEmit, bool>> agenda;
  agenda.emplace(ComputationToEmit{root, false}, false);
  while (!agenda.empty()) {
    ComputationToEmit c;
    bool leave;
    std::tie(c, leave) = agenda.top();
    agenda.pop();

    if (leave) {
      postorder.push_back(c);
      continue;
    }

    if (visited.insert(c).second) {
      agenda.emplace(c, true);
      for (auto* instruction : c.computation->instructions()) {
        bool allow_reassociation =
            instruction->opcode() == HloOpcode::kAllReduce ||
            instruction->opcode() == HloOpcode::kReduce ||
            instruction->opcode() == HloOpcode::kReduceWindow;
        for (auto it = instruction->called_computations().rbegin();
             it != instruction->called_computations().rend(); ++it) {
          HloComputation* called_computation = *it;
          ComputationToEmit callee{
              called_computation, c.allow_reassociation || allow_reassociation};
          if (!visited.contains(callee)) {
            agenda.emplace(callee, false);
          }
        }
      }
    }
  }
  DCHECK(!postorder.empty() && postorder.back().computation == root);
  postorder.pop_back();
  return postorder;
}
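
// Worked example (hypothetical call graph): if `root` calls computations A
// and B, and A calls C, the returned order places C before A (callees always
// precede their callers) and omits `root` itself, which is emitted separately
// after its subcomputations.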
1113 
1114 }  // namespace
1115 
StatusOr<std::unique_ptr<CpuExecutable>>
CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {
  ModuleHook pre_optimization_ir_hook;
  ModuleHook post_optimization_ir_hook;
  std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
      GetIRModuleHooks(*module, user_pre_optimization_hook_,
                       user_post_optimization_hook_);

  // Compile must be thread-safe so create a new LLVM context for the module.
  mlir::MLIRContext mlir_context;
  LoadMLIRDialects(mlir_context);
  auto llvm_context = std::make_unique<llvm::LLVMContext>();
  auto llvm_module =
      std::make_unique<llvm::Module>("__compute_module", *llvm_context);

  auto jit = SimpleOrcJIT::Create(
      CompilerTargetOptions(module->config()),
      CodeGenOptLevel(module->config()),
      options::OptimizeForSizeRequested(module->config()),
      module->config().debug_options().xla_llvm_disable_expensive_passes(),
      llvm_ir::GetCpuFastMathFlags(module->config()), pre_optimization_ir_hook,
      post_optimization_ir_hook,
      OrcJITPostCompilationHook::Create(module.get()));
  if (!jit) {
    return InternalError("Creating JIT failed: %s",
                         llvm::toString(jit.takeError()));
  }
  llvm_module->setDataLayout((*jit)->data_layout());
  llvm_module->setTargetTriple((*jit)->target_triple().getTriple());

  HloComputation* entry_computation = module->entry_computation();
  absl::flat_hash_map<const HloInstruction*, int64_t>
      instruction_to_profile_idx;
  absl::flat_hash_map<const HloComputation*, int64_t>
      computation_to_profile_idx;
  std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
  std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;
  if (module->config().hlo_profiling_enabled()) {
    TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
        *module, &instruction_to_profile_idx, &computation_to_profile_idx,
        &hlo_profile_index_map, &hlo_profile_printer_data));
  }

  // Cache this flag here since we'll want to access it after the module's
  // ownership has been std::move()'d away.
  const bool embed_ir_in_executable =
      module->config().debug_options().xla_embed_ir_in_executable();

  // Select an order for emitting the HLO instructions for each
  // computation. Using this sequence enables tighter buffer liveness analysis
  // and reduced memory usage (as compared to using DependencyHloOrdering).
  TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                      ScheduleModule(module.get(), BufferSizeBytesFunction(),
                                     ComputationSchedulerToModuleScheduler(
                                         DFSMemoryScheduler)));

  // Run buffer allocation on the HLO graph.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<BufferAssignment> assignment,
      BufferAssigner::Run(module.get(),
                          std::make_unique<SequentialHloOrdering>(schedule),
                          BufferSizeBytesFunction(), memory_alignment,
                          /*allocate_buffers_for_constants=*/true));
  DumpHloModuleIfEnabled(*module, *assignment, "cpu_after_optimizations");

  // Each computation is a single function.  Emit all embedded computations
  // before the entry computation. The order returned by
  // SubcomputationEmissionOrder guarantees that a called computation is
  // emitted before its callers.

  std::string function_name;
  if (UseMlirHloLowering(
          module->config().debug_options().xla_cpu_enable_mlir_lowering(),
          module.get())) {
    TF_ASSIGN_OR_RETURN(
        auto mlir_module,
        createMLIRModule(module.get(), mlir_context, assignment.get()));
    TF_RETURN_IF_ERROR(LowerMLIRModule(*mlir_module, mlir_context));

    function_name = entry_computation->name();
    // TODO(kramerb): Don't rely on the exact function name.
    llvm::cast<mlir::LLVM::LLVMFuncOp>(
        mlir_module->lookupSymbol("main_xla_framework"))
        .setName(function_name);

    llvm_module = mlir::translateModuleToLLVMIR(*mlir_module, *llvm_context);
    if (!llvm_module) {
      return InternalError("Translation to LLVM IR failed");
    }
    llvm_module->setDataLayout((*jit)->data_layout());
    llvm_module->setTargetTriple((*jit)->target_triple().getTriple());
  } else {
    LLVMTargetMachineFeatures target_machine_features((*jit)->target_machine());
    IrEmitter ir_emitter(
        &mlir_context, *module, *assignment, llvm_module.get(),
        std::move(instruction_to_profile_idx),
        std::move(computation_to_profile_idx),
        ModuleComputationsTransitivelyContainCustomCall(*module),
        &target_machine_features,
#ifdef MEMORY_SANITIZER
        /*emit_code_for_msan=*/true
#else
        /*emit_code_for_msan=*/false
#endif
    );

    TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());

    for (ComputationToEmit subcomputation :
         SubcomputationEmissionOrder(entry_computation)) {
      if (subcomputation.computation->IsFusionComputation()) {
        continue;
      }
      TF_RETURN_IF_ERROR(
          ir_emitter
              .EmitComputation(
                  subcomputation.computation,
                  subcomputation.computation->name(),
                  /*is_top_level_computation=*/false,
                  schedule.sequence(subcomputation.computation).instructions(),
                  subcomputation.allow_reassociation)
              .status());
    }
    std::string function_name_prefix = entry_computation->name().empty()
                                           ? "__compute"
                                           : entry_computation->name();
    TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
                        ir_emitter.EmitComputation(
                            entry_computation, function_name_prefix,
                            /*is_top_level_computation=*/true,
                            schedule.sequence(entry_computation).instructions(),
                            /*allow_reassociation=*/false));

    function_name = [&]() {
      llvm::SmallVector<char, 40> function_name_vector;
      llvm::Mangler::getNameWithPrefix(function_name_vector,
                                       entry_function->getName(),
                                       (*jit)->data_layout());
      return std::string(function_name_vector.begin(),
                         function_name_vector.end());
    }();
  }
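
  // Note (illustrative): llvm::Mangler::getNameWithPrefix applies the global
  // name prefix from the data layout (e.g. a leading '_' on Mach-O targets),
  // so function_name matches the symbol the JIT will actually define rather
  // than the raw IR function name.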

  std::string ir_module_string;
  if (embed_ir_in_executable) {
    ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
  }

  TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));

  // JIT compile the LLVM IR module to in-memory machine code.
  llvm::orc::ThreadSafeModule thread_safe_module(std::move(llvm_module),
                                                 std::move(llvm_context));
  cantFail((*jit)->AddModule(std::move(thread_safe_module)));
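
  // Note (illustrative): orc::ThreadSafeModule bundles the module with the
  // LLVMContext that owns its IR, so both are handed to the JIT together and
  // the context stays alive for as long as the module needs it.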

  auto cpu_executable = std::make_unique<CpuExecutable>(
      std::move(*jit), std::move(assignment), std::move(module), function_name,
      std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map));

  if (embed_ir_in_executable) {
    cpu_executable->set_ir_module_string(ir_module_string);
  }

  // Dump computation proto state and buffer assignment for debug and test, if
  // dumping is enabled or embed_ir_in_executable is set.
  if (embed_ir_in_executable ||
      DumpingEnabledForHloModule(cpu_executable->module())) {
    auto hlo_proto = std::make_unique<HloProto>();
    *hlo_proto->mutable_hlo_module() = cpu_executable->module().ToProto();
    *hlo_proto->mutable_buffer_assignment() =
        cpu_executable->buffer_assignment().ToProto();
    cpu_executable->set_hlo_proto(std::move(hlo_proto));
  }

  return cpu_executable;
}

StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
    std::unique_ptr<HloModule> module,
    [[maybe_unused]] se::StreamExecutor* stream_exec,
    [[maybe_unused]] const CompileOptions& options) {
  VLOG(1) << "Compiling: " << module->name();
  XLA_SCOPED_LOGGING_TIMER(
      absl::StrFormat("Compiling [%s] for CPU using JIT", module->name()));
  std::string slow_compilation_msg =
      absl::StrCat("Compiling module ", module->name());
  auto slow_compile_alarm = SlowCompilationAlarm(slow_compilation_msg);

  absl::call_once(llvm_command_line_options_initialized,
                  &InitializeLLVMCommandLineOptions, module->config());

  std::unique_ptr<CpuExecutable> cpu_executable;
  TF_ASSIGN_OR_RETURN(cpu_executable,
                      CompileLegacyCpuExecutable(std::move(module)));

  cpu_executable->set_debug_info(
      cpu_executable->buffer_assignment().GetStats().ToString());
  VLOG(1) << "Compilation finished";
  return std::unique_ptr<Executable>(std::move(cpu_executable));
}
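
// Usage sketch (illustrative, not in the original source; passing a null
// stream executor is an assumption permitted by the [[maybe_unused]]
// parameters above):
//
//   xla::cpu::CpuCompiler compiler;
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<xla::Executable> executable,
//       compiler.RunBackend(std::move(hlo_module), /*stream_exec=*/nullptr,
//                           xla::Compiler::CompileOptions{}));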

StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CpuCompiler::CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                                const AotCompilationOptions& aot_options) {
  TF_RET_CHECK(!module_group->empty());
  std::vector<std::unique_ptr<HloModule>> modules =
      module_group->ConsumeModules();

  absl::call_once(llvm_command_line_options_initialized,
                  &InitializeLLVMCommandLineOptions, modules[0]->config());

  // We can pass just one llvm::TargetOptions when we compile the LLVM module,
  // so we bail if the configs have conflicting flags. At the moment, the only
  // flags that need to be consistent are for fast-math.
  for (const auto& fn_and_name :
       {std::make_pair(&DebugOptions::xla_cpu_enable_fast_math,
                       "xla_cpu_enable_fast_math"),
        std::make_pair(&DebugOptions::xla_cpu_fast_math_honor_infs,
                       "xla_cpu_fast_math_honor_infs"),
        std::make_pair(&DebugOptions::xla_cpu_fast_math_honor_nans,
                       "xla_cpu_fast_math_honor_nans")}) {
    // This only works because each of the method pointers above returns a
    // bool. Otherwise we'd have to do some template magic.
    const auto& field_method_ptr = fn_and_name.first;
    const auto& field_name = fn_and_name.second;
    bool first_module_val =
        (modules[0]->config().debug_options().*field_method_ptr)();
    for (int64_t i = 0; i < modules.size(); ++i) {
      bool cur_module_val =
          (modules[i]->config().debug_options().*field_method_ptr)();
      if (first_module_val != cur_module_val) {
        return InvalidArgument(
            "All HLO module configs must have the same value for %s, but "
            "module 0 and %d have different values (%d vs %d).",
            field_name, i, first_module_val, cur_module_val);
      }
    }
  }
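
  // Note (illustrative): the "template magic" comment above refers to C++
  // pointer-to-member-function syntax. Given
  //
  //   auto field = &DebugOptions::xla_cpu_enable_fast_math;
  //
  // the expression (config.debug_options().*field)() calls the accessor
  // through the pointer; this stays simple only because every listed accessor
  // returns bool.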

  if (aot_options.PlatformId() != se::host::kHostPlatformId) {
    return InvalidArgument("Incompatible AOT compilation platform");
  }
  const CpuAotCompilationOptions& options =
      static_cast<const CpuAotCompilationOptions&>(aot_options);
  llvm::Triple triple(llvm::Triple::normalize(options.triple()));
  std::string error;
  const llvm::Target* target =
      llvm::TargetRegistry::lookupTarget(triple.getTriple(), error);
  if (target == nullptr) {
    return InternalError("TargetRegistry::lookupTarget failed: %s", error);
  }

  llvm::Reloc::Model reloc_model = llvm::Reloc::Static;
  llvm::PICLevel::Level pic_level = llvm::PICLevel::NotPIC;
  llvm::PIELevel::Level pie_level = llvm::PIELevel::Default;
  switch (options.relocation_model()) {
    case CpuAotCompilationOptions::RelocationModel::Static:
      reloc_model = llvm::Reloc::Static;
      pic_level = llvm::PICLevel::NotPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::SmallPic:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::SmallPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::BigPic:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::BigPIC;
      pie_level = llvm::PIELevel::Default;
      break;
    case CpuAotCompilationOptions::RelocationModel::SmallPie:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::SmallPIC;
      pie_level = llvm::PIELevel::Small;
      break;
    case CpuAotCompilationOptions::RelocationModel::BigPie:
      reloc_model = llvm::Reloc::PIC_;
      pic_level = llvm::PICLevel::BigPIC;
      pie_level = llvm::PIELevel::Large;
      break;
  }
  llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel(modules[0]->config());
  std::unique_ptr<llvm::TargetMachine> target_machine =
      absl::WrapUnique(target->createTargetMachine(
          triple.getTriple(), options.cpu_name(), options.features(),
          CompilerTargetOptions(modules[0]->config()), reloc_model, llvm::None,
          opt_level));
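
  // Note (illustrative): the llvm::None passed to createTargetMachine above
  // is the optional llvm::CodeModel::Model argument; leaving it unset lets
  // LLVM pick a default code model for the target.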

  // Compile must be thread-safe so create a new LLVM context for the module.
  mlir::MLIRContext mlir_context;
  LoadMLIRDialects(mlir_context);
  llvm::LLVMContext llvm_context;
  std::unique_ptr<llvm::Module> llvm_module;

  std::vector<std::unique_ptr<AotCompilationResult>> results;
  for (size_t i = 0; i < modules.size(); ++i) {
    HloModule* module = modules[i].get();
    VLOG(1) << "Compiling ahead-of-time: " << module->name();

    TF_RETURN_IF_ERROR(
        RunHloPasses(module, /*is_aot_compile=*/true, target_machine.get(),
                     /*is_mlir_compile=*/options.use_mlir_hlo_lowering()));

    TF_ASSIGN_OR_RETURN(HloSchedule schedule,
                        ScheduleModule(module, BufferSizeBytesFunction()));

    // Run buffer analysis on the HLO graph. This analysis figures out which
    // temporary buffers are required to run the computation.
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<BufferAssignment> assignment,
        BufferAssigner::Run(module,
                            std::make_unique<SequentialHloOrdering>(schedule),
                            BufferSizeBytesFunction(), memory_alignment,
                            /*allocate_buffers_for_constants=*/true));
    // BufferAssignment::ToString() includes a header, so no need for us to
    // print one ourselves.
    if (DumpingEnabledForHloModule(*module)) {
      DumpToFileInDirOrStdout(*module, "", "buffer_assignment",
                              assignment->ToString());
    }
    DumpHloModuleIfEnabled(*module, *assignment, "cpu_after_optimizations");

    absl::flat_hash_map<const HloInstruction*, int64_t>
        instruction_to_profile_idx;
    absl::flat_hash_map<const HloComputation*, int64_t>
        computation_to_profile_idx;
    std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
    std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data;

    if (module->config().hlo_profiling_enabled()) {
      TF_RETURN_IF_ERROR(CreateHloProfilingArtifacts(
          *module, &instruction_to_profile_idx, &computation_to_profile_idx,
          &hlo_profile_index_map, &hlo_profile_printer_data));
    }

    LLVMTargetMachineFeatures target_machine_features(target_machine.get());
    std::vector<BufferInfo> buffer_infos =
        CreateBufferInfosFromBufferAssignment(*assignment);
    HloComputation* computation = module->entry_computation();

    if (UseMlirHloLowering(options.use_mlir_hlo_lowering(), module)) {
      TF_ASSIGN_OR_RETURN(
          auto mlir_module,
          createMLIRModule(module, mlir_context, assignment.get()));
      TF_RETURN_IF_ERROR(LowerMLIRModule(*mlir_module, mlir_context));

      llvm::cast<mlir::LLVM::LLVMFuncOp>(
          mlir_module->lookupSymbol("main_xla_framework"))
          .setName(options.entry_point_name());

      llvm_module = mlir::translateModuleToLLVMIR(*mlir_module, llvm_context);
      if (!llvm_module) {
        return InternalError("Translation to LLVM IR failed");
      }
      // Set the information that the MLIR-to-LLVM translation does not fill
      // in.
      llvm_module->setDataLayout(target_machine->createDataLayout());
      llvm_module->setTargetTriple(triple.getTriple());
      if (pic_level != llvm::PICLevel::NotPIC) {
        llvm_module->setPICLevel(pic_level);
      }
      if (pie_level != llvm::PIELevel::Default) {
        llvm_module->setPIELevel(pie_level);
      }
    } else {
      // Set required information before emitting IR.
      llvm_module =
          std::make_unique<llvm::Module>("__compute_module", llvm_context);
      llvm_module->setDataLayout(target_machine->createDataLayout());
      llvm_module->setTargetTriple(triple.getTriple());
      if (pic_level != llvm::PICLevel::NotPIC) {
        llvm_module->setPICLevel(pic_level);
      }
      if (pie_level != llvm::PIELevel::Default) {
        llvm_module->setPIELevel(pie_level);
      }
      IrEmitter ir_emitter(
          &mlir_context, *module, *assignment, llvm_module.get(),
          std::move(instruction_to_profile_idx),
          std::move(computation_to_profile_idx),
          ModuleComputationsTransitivelyContainCustomCall(*module),
          &target_machine_features,
          // TODO(b/66051036): Run full msan for AOT.
          /*emit_code_for_msan=*/false);

      TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());

      for (ComputationToEmit subcomputation :
           SubcomputationEmissionOrder(computation)) {
        if (subcomputation.computation->IsFusionComputation()) {
          continue;
        }
        TF_RETURN_IF_ERROR(
            ir_emitter
                .EmitComputation(subcomputation.computation,
                                 subcomputation.computation->name(),
                                 /*is_top_level_computation=*/false,
                                 schedule.sequence(subcomputation.computation)
                                     .instructions(),
                                 subcomputation.allow_reassociation)
                .status());
      }
      const std::string& entry_point_name = options.entry_point_name();
      TF_ASSIGN_OR_RETURN(llvm::Function * entry_function,
                          ir_emitter.EmitComputation(
                              computation, entry_point_name,
                              /*is_top_level_computation=*/true,
                              schedule.sequence(computation).instructions(),
                              /*allow_reassociation=*/false));

      CHECK(entry_function->getName() == entry_point_name);
    }

    ModuleHook pre_optimization_ir_hook;
    ModuleHook post_optimization_ir_hook;
    std::tie(pre_optimization_ir_hook, post_optimization_ir_hook) =
        GetIRModuleHooks(*module, user_pre_optimization_hook_,
                         user_post_optimization_hook_);

    // Run the LLVM verifier over the unoptimized LLVM IR.  If it fails, run
    // the pre-optimization IR dump hook before returning.
    {
      Status verify_status = VerifyLlvmModule(*llvm_module);
      if (!verify_status.ok() && pre_optimization_ir_hook) {
        pre_optimization_ir_hook(*llvm_module);
      }
      TF_RETURN_IF_ERROR(verify_status);
    }

    auto post_codegen_hook = [&](const llvm::object::ObjectFile& obj_file) {
      if (!DumpingEnabledForHloModule(*module)) {
        return;
      }
      DumpToFileInDir(*module, /*file_prefix=*/"", /*file_suffix=*/"o",
                      absl::string_view(obj_file.getData().data(),
                                        obj_file.getData().size()));
    };

    CompilerFunctor compiler_functor(
        target_machine.get(), opt_level,
        options::OptimizeForSizeRequested(module->config()),
        module->config().debug_options().xla_llvm_disable_expensive_passes(),
        llvm_ir::GetCpuFastMathFlags(module->config()),
        pre_optimization_ir_hook, post_optimization_ir_hook, post_codegen_hook,
        aot_options.sanitize_dataflow(),
        aot_options.sanitize_abilists_dataflow());
    std::unique_ptr<llvm::MemoryBuffer> object_file =
        cantFail(compiler_functor(*llvm_module));
    ObjectFileData object_file_data(object_file->getBufferStart(),
                                    object_file->getBufferEnd());

    TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
                        assignment->GetUniqueTopLevelOutputSlice());

    results.emplace_back(std::make_unique<CpuAotCompilationResult>(
        std::move(object_file_data), std::move(buffer_infos),
        result_slice.index(), std::move(hlo_profile_printer_data)));
  }

  VLOG(1) << "Compilation finished";
  return std::move(results);
}
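
// Usage sketch (illustrative, not in the original source; the option values
// and the exact CpuAotCompilationOptions constructor shape are assumptions):
//
//   xla::cpu::CpuAotCompilationOptions aot_options(
//       /*triple=*/"x86_64-unknown-linux-gnu", /*cpu_name=*/"",
//       /*features=*/"", /*entry_point_name=*/"entry",
//       xla::cpu::CpuAotCompilationOptions::RelocationModel::Static);
//   TF_ASSIGN_OR_RETURN(
//       auto results,
//       compiler.CompileAheadOfTime(std::move(module_group), aot_options));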

se::Platform::Id CpuCompiler::PlatformId() const {
  return se::host::kHostPlatformId;
}

HloCostAnalysis::ShapeSizeFunction CpuCompiler::ShapeSizeBytesFunction() const {
  return CpuExecutable::ShapeSizeBytes;
}

}  // namespace cpu
}  // namespace xla

static bool InitModule() {
  xla::Compiler::RegisterCompilerFactory(
      stream_executor::host::kHostPlatformId,
      []() { return std::make_unique<xla::cpu::CpuCompiler>(); });
  return true;
}
static bool module_initialized = InitModule();
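// Note: module_initialized exists solely to force InitModule() to run during
// static initialization, so the CpuCompiler factory is registered for the
// host platform before any compilation can be requested.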