xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/cpu/cpu_compiler.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COMPILER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COMPILER_H_
18 
19 #include <memory>
20 
21 #include "absl/types/span.h"
22 #include "llvm/Target/TargetMachine.h"
23 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
24 #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
25 #include "tensorflow/compiler/xla/service/executable.h"
26 #include "tensorflow/compiler/xla/service/hlo_module.h"
27 #include "tensorflow/compiler/xla/service/llvm_compiler.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
30 
31 namespace xla {
32 namespace cpu {
33 
34 class CpuExecutable;
35 
36 // This class wraps the configurability options that LLVM exposes including: the
37 // target triple, the target cpu and the target features.  It also includes the
38 // desired linkage name for the computation entry point.
39 class CpuAotCompilationOptions : public AotCompilationOptions {
40  public:
41   // Relocation models available for compilation.
42   enum class RelocationModel {
43     // Corresponds to the -fno-pic compiler option.
44     Static,
45     // Corresponds to the -fpic compiler option.
46     SmallPic,
47     // Corresponds to the -fPIC compiler option.
48     BigPic,
49     // Corresponds to the -fpie compiler option.
50     SmallPie,
51     // Corresponds to the -fPIE compiler option.
52     BigPie
53   };
54 
55   CpuAotCompilationOptions(std::string triple, std::string cpu_name,
56                            std::string features, std::string entry_point_name,
57                            RelocationModel relocation_model);
58 
59   ~CpuAotCompilationOptions() override;
60 
61   se::Platform::Id PlatformId() const override;
62 
63   // The triple used for compilation, similar to clang's -target flag.
triple()64   const std::string& triple() const { return triple_; }
65   // The CPU name used for compilation, similar to clang's -mcpu flag.
cpu_name()66   const std::string& cpu_name() const { return cpu_name_; }
67   // The target features used for compilation ("+avx2", "+neon", etc).
features()68   const std::string& features() const { return features_; }
69   // The name to be used for the compiled code's entry point.
entry_point_name()70   const std::string& entry_point_name() const { return entry_point_name_; }
71   // The relocation model used for compilation.
relocation_model()72   RelocationModel relocation_model() const { return relocation_model_; }
73 
use_mlir_hlo_lowering()74   bool use_mlir_hlo_lowering() const { return use_mlir_hlo_lowering_; }
set_use_mlir_hlo_lowering(bool value)75   void set_use_mlir_hlo_lowering(bool value) { use_mlir_hlo_lowering_ = value; }
76 
77  private:
78   const std::string triple_;
79   const std::string cpu_name_;
80   const std::string features_;
81   const std::string entry_point_name_;
82   const RelocationModel relocation_model_;
83   bool use_mlir_hlo_lowering_ = false;
84 };
85 
86 class CpuAotCompilationResult : public AotCompilationResult {
87  public:
88   CpuAotCompilationResult(
89       ObjectFileData object_file_data,
90       std::vector<cpu_function_runtime::BufferInfo> buffer_infos,
91       int64_t result_buffer_index,
92       std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data);
93   ~CpuAotCompilationResult();
94 
hlo_profile_printer_data()95   HloProfilePrinterData* hlo_profile_printer_data() const {
96     return hlo_profile_printer_data_.get();
97   }
98 
object_file_data()99   const ObjectFileData& object_file_data() const { return object_file_data_; }
buffer_infos()100   const std::vector<cpu_function_runtime::BufferInfo>& buffer_infos() const {
101     return buffer_infos_;
102   }
result_buffer_index()103   int64_t result_buffer_index() const { return result_buffer_index_; }
104 
105  private:
106   // Contains the compiled computation: an object file.
107   const ObjectFileData object_file_data_;
108 
109   // A list of BufferInfo objects describing the buffers used by the XLA
110   // computation.
111   const std::vector<cpu_function_runtime::BufferInfo> buffer_infos_;
112 
113   // Contains which buffer index into |buffer_sizes| was designated to the
114   // result of the computation.  This buffer should be passed into the output
115   // parameter when calling the compiled computation.
116   const int64_t result_buffer_index_;
117 
118   // Contains an instance of HloProfilePrinterData if HLO profiling is enabled,
119   // otherwise is nullptr.
120   std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data_;
121 };
122 
123 // CPU-targeting implementation of the XLA Compiler interface.
124 //
125 // The compiler translates XLA HLO code into LLVM IR and uses LLVM's JIT
126 // infrastructure to create an executable "blob" that can then be returned
127 // wrapped in CpuExecutable and actually invoked.
128 class CpuCompiler : public LLVMCompiler {
129  public:
130   CpuCompiler();
~CpuCompiler()131   ~CpuCompiler() override {}
132 
133   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
134       std::unique_ptr<HloModuleGroup> module_group,
135       std::vector<std::vector<se::StreamExecutor*>> stream_execs,
136       const CompileOptions& options) override;
137 
138   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
139       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
140       const CompileOptions& options) override;
141 
142   StatusOr<std::unique_ptr<BufferAssignment>> AssignBuffers(
143       const HloModule* module) override;
144 
145   StatusOr<std::unique_ptr<Executable>> RunBackend(
146       std::unique_ptr<HloModule> module, se::StreamExecutor* stream_exec,
147       const CompileOptions& options) override;
148 
149   StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
150   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
151                      const AotCompilationOptions& options) override;
152 
153   se::Platform::Id PlatformId() const override;
154 
155   HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override;
156 
157  private:
158   // Initialize the LLVM target.
159   static void InitializeLLVMTarget();
160 
161   // Runs the HLO passes which are necessary for both optimizations and
162   // correctness.
163   Status RunHloPasses(HloModule* module, bool is_aot_compile,
164                       llvm::TargetMachine* target_machine,
165                       bool is_mlir_compile = false);
166 
167   // Runs HLO passes up to and including layout assignment.
168   Status RunHloPassesThroughLayoutAssn(
169       HloModule* module, bool /*is_aot_compile*/,
170       LLVMTargetMachineFeatures* target_machine_features,
171       bool is_mlir_compile = false);
172 
173   // Runs HLO passes after layout assignment.
174   Status RunHloPassesAfterLayoutAssn(
175       HloModule* module, bool is_aot_compile,
176       LLVMTargetMachineFeatures* target_machine_features, bool is_mlir_compile);
177 
178   StatusOr<std::unique_ptr<CpuExecutable>> CompileLegacyCpuExecutable(
179       std::unique_ptr<HloModule> module);
180 
181   CpuCompiler(const CpuCompiler&) = delete;
182   CpuCompiler& operator=(const CpuCompiler&) = delete;
183 };
184 
185 }  // namespace cpu
186 }  // namespace xla
187 
188 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COMPILER_H_
189