xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/compiler.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // The compiler API is used by the XLA service to generate executables that
17 // run on a given platform. This is a registry and abstract interface, for
18 // pluggability by the various platforms.
19 
20 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
21 #define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
22 
#include <functional>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/threadpool.h"
45 
46 namespace xla {
47 
// Declarations shared by the ahead-of-time (AOT) compilation API below.

class Compiler;

// Raw bytes of an object file produced by ahead-of-time compilation.
using ObjectFileData = std::vector<char>;
55 
56 // Abstract superclass describing the result of an ahead-of-time compilation.
57 class AotCompilationResult {
58  public:
59   AotCompilationResult(const AotCompilationResult&) = delete;
60   AotCompilationResult& operator=(AotCompilationResult const&) = delete;
61 
62   virtual ~AotCompilationResult() = default;
63 
SerializeAsString()64   virtual StatusOr<std::string> SerializeAsString() const {
65     return Unimplemented("SerializeAsString unimplemented.");
66   }
67 
LoadExecutable(Compiler * compiler,se::StreamExecutor * executor)68   virtual StatusOr<std::unique_ptr<Executable>> LoadExecutable(
69       Compiler* compiler, se::StreamExecutor* executor) const {
70     return Unimplemented("LoadExecutable unimplemented.");
71   }
72 
73  protected:
74   AotCompilationResult() = default;
75 };
76 
77 // Abstract superclass describing options to an ahead-of-time compilation.
78 class AotCompilationOptions {
79  public:
80   AotCompilationOptions(const AotCompilationOptions&) = delete;
81   AotCompilationOptions& operator=(AotCompilationOptions const&) = delete;
82 
AotCompilationOptions(se::Platform::Id platform_id)83   explicit AotCompilationOptions(se::Platform::Id platform_id)
84       : platform_id_(platform_id), debug_options_(GetDebugOptionsFromFlags()) {}
85   virtual ~AotCompilationOptions() = default;
86 
87   // Returns the ID of the platform to which these options apply.
PlatformId()88   virtual se::Platform::Id PlatformId() const { return platform_id_; }
89 
replica_count()90   virtual int64_t replica_count() const { return 0; }
num_cores()91   virtual int64_t num_cores() const { return 0; }
use_spmd_partitioning()92   virtual bool use_spmd_partitioning() const { return false; }
use_auto_spmd_partitioning()93   virtual bool use_auto_spmd_partitioning() const { return false; }
auto_spmd_partitioning_mesh_shape()94   virtual std::vector<int64_t> auto_spmd_partitioning_mesh_shape() const {
95     return {};
96   }
auto_spmd_partitioning_mesh_ids()97   virtual std::vector<int64_t> auto_spmd_partitioning_mesh_ids() const {
98     return {};
99   }
deduplicate_hlo()100   virtual bool deduplicate_hlo() const { return false; }
matrix_unit_operand_precision()101   virtual PrecisionConfig::Precision matrix_unit_operand_precision() const {
102     return PrecisionConfig::DEFAULT;
103   }
104 
105   // Optional allocator that may be used for allocating temp space on the device
106   // during compilation.
device_allocator()107   se::DeviceMemoryAllocator* device_allocator() const {
108     return device_allocator_;
109   }
set_device_allocator(se::DeviceMemoryAllocator * device_allocator)110   void set_device_allocator(se::DeviceMemoryAllocator* device_allocator) {
111     device_allocator_ = device_allocator;
112   }
113 
debug_options()114   const DebugOptions& debug_options() const { return debug_options_; }
mutable_debug_options()115   DebugOptions* mutable_debug_options() { return &debug_options_; }
116 
has_static_device_assignment()117   bool has_static_device_assignment() const {
118     return static_device_assignment_.has_value();
119   }
static_device_assignment()120   const DeviceAssignment& static_device_assignment() const {
121     CHECK(static_device_assignment_.has_value());
122     return *static_device_assignment_;
123   }
set_static_device_assignment(const DeviceAssignment & device_assignment)124   void set_static_device_assignment(const DeviceAssignment& device_assignment) {
125     static_device_assignment_ = device_assignment;
126   }
127 
fusion_config_collection()128   FusionConfigCollection fusion_config_collection() const {
129     return fusion_config_collection_;
130   }
set_fusion_config_collection(FusionConfigCollection fusion_config_collection)131   void set_fusion_config_collection(
132       FusionConfigCollection fusion_config_collection) {
133     fusion_config_collection_ = fusion_config_collection;
134   }
135 
fusion_config()136   const std::vector<std::vector<bool>>& fusion_config() const {
137     return fusion_config_;
138   }
set_fusion_config(const std::vector<std::vector<bool>> & fusion_config)139   void set_fusion_config(const std::vector<std::vector<bool>>& fusion_config) {
140     fusion_config_ = fusion_config;
141   }
142 
executor()143   se::StreamExecutor* executor() const { return executor_; }
set_executor(se::StreamExecutor * executor)144   void set_executor(se::StreamExecutor* executor) { executor_ = executor; }
145 
146   // Optional profile_version and cache key may be used to trigger recompilation
147   // when a compilation cache is used.
profile_version()148   int64_t profile_version() const { return profile_version_; }
set_profile_version(int64_t profile_version)149   void set_profile_version(int64_t profile_version) {
150     profile_version_ = profile_version;
151   }
152 
cache_key()153   absl::string_view cache_key() const { return cache_key_; }
set_cache_key(absl::string_view cache_key)154   void set_cache_key(absl::string_view cache_key) {
155     cache_key_ = std::string(cache_key);
156   }
157 
run_backend_only()158   bool run_backend_only() const { return run_backend_only_; }
set_run_backend_only(bool run_backend_only)159   void set_run_backend_only(bool run_backend_only) {
160     run_backend_only_ = run_backend_only;
161   }
162 
sanitize_dataflow()163   bool sanitize_dataflow() const { return sanitize_dataflow_; }
set_sanitize_dataflow(bool sanitize_dataflow)164   void set_sanitize_dataflow(bool sanitize_dataflow) {
165     sanitize_dataflow_ = sanitize_dataflow;
166   }
167 
sanitize_abilists_dataflow()168   const std::vector<std::string>& sanitize_abilists_dataflow() const {
169     return sanitize_abilists_dataflow_;
170   }
set_sanitize_abilists_dataflow(const std::vector<std::string> & abilists)171   void set_sanitize_abilists_dataflow(
172       const std::vector<std::string>& abilists) {
173     sanitize_abilists_dataflow_ = abilists;
174   }
175 
176  protected:
177   AotCompilationOptions();
178 
179  private:
180   se::Platform::Id platform_id_;
181   se::DeviceMemoryAllocator* device_allocator_ = nullptr;
182   DebugOptions debug_options_;
183   std::optional<DeviceAssignment> static_device_assignment_;
184   std::vector<std::vector<bool>> fusion_config_;
185   FusionConfigCollection fusion_config_collection_ =
186       FusionConfigCollection::kOff;
187   se::StreamExecutor* executor_ = nullptr;
188   int64_t profile_version_ = 0;
189   std::string cache_key_;
190   bool run_backend_only_ = false;
191   bool sanitize_dataflow_ = false;
192   std::vector<std::string> sanitize_abilists_dataflow_;
193 };
194 
// Base class for metadata that can be filled in during ahead-of-time
// compilation (see the CompileAheadOfTime overload that takes a metadata
// out-parameter).
class AotCompilationMetadata {
 public:
  virtual ~AotCompilationMetadata() = default;

  // Non-copyable.
  AotCompilationMetadata(const AotCompilationMetadata&) = delete;
  AotCompilationMetadata& operator=(const AotCompilationMetadata&) = delete;

  // Human-readable description. The base implementation returns an empty
  // string.
  virtual std::string ToString() const { return std::string(); }

 protected:
  // Only subclasses can be instantiated.
  AotCompilationMetadata() = default;
};
207 
208 // Abstract compiler interface that is subclassed for compilation on a
209 // particular platform.
210 //
211 // The compiler ties together high level optimization (HLO) and low level
212 // optimization (LLO) / codegen (CG) to generate efficient executables for the
213 // target platform.
214 //
215 // The platform-based compiler singletons are registered via module initializers
216 // in their corresponding XLA compiler libraries, and are registered via the
217 // RegisterCompilerFactory API below.
218 //
219 // Thread-safety: subclasses of Compiler must be thread-safe, as multiple
220 // XLA clients may be requesting compilation concurrently for a given
221 // platform.
222 class Compiler {
223  public:
224   struct CompileOptions {
225     // If device_allocator is not null, the compiler may use it to allocate temp
226     // space on the device for use during compilation.  For example, the
227     // compiler may allocate buffers on the device and then run variants of a
228     // given algorithm over those buffers, to see which variant is fastest.  Any
229     // space allocated will be deallocated before the compilation returns.
230     se::DeviceMemoryAllocator* device_allocator = nullptr;
231 
232     // An optional thread pool for parallel compilation.
233     tensorflow::thread::ThreadPool* thread_pool = nullptr;
234   };
235 
~Compiler()236   virtual ~Compiler() {}
237 
238   // Returns the ID of the platform that this compiler targets.
239   virtual se::Platform::Id PlatformId() const = 0;
240 
241   // Runs Hlo passes to optimize the given Hlo module, returns the optimized
242   // module.
243   virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
244       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
245       const CompileOptions& options) = 0;
RunHloPasses(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,se::DeviceMemoryAllocator * device_allocator)246   StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
247       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
248       se::DeviceMemoryAllocator* device_allocator) {
249     return RunHloPasses(std::move(module), executor,
250                         CompileOptions{device_allocator});
251   }
252 
253   // Performs scheduling and buffer assignment and returns the buffer
254   // assignments.
255   // The returned 'BufferAssignment' retains a pointer to the 'HloModule', so
256   // the module must live at least as long as the buffer assignments.
AssignBuffers(const HloModule * module)257   virtual StatusOr<std::unique_ptr<BufferAssignment>> AssignBuffers(
258       const HloModule* module) {
259     return Unimplemented("This compiler does not support this method");
260   }
261 
262   // Compiles the HLO module for execution on a device given by the executor,
263   // and returns an executable object or an error status. No HLO passes are
264   // applied to module. Generally a module should be passed through RunHloPasses
265   // prior to calling this method because some HLO passes are required for
266   // correctness. Takes ownership of the HLO module.
267   //
268   // The compiler may optionally specialize to the individual device
269   // (not just type of device) indicated by the executor.
270   virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
271       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
272       const CompileOptions& options) = 0;
RunBackend(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,se::DeviceMemoryAllocator * device_allocator)273   StatusOr<std::unique_ptr<Executable>> RunBackend(
274       std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
275       se::DeviceMemoryAllocator* device_allocator) {
276     return RunBackend(std::move(module), executor,
277                       CompileOptions{device_allocator});
278   }
279 
280   // Returns a (deserialized) AotCompilationResult from a serialized
281   // AotCompilationResult.
282   virtual StatusOr<std::unique_ptr<AotCompilationResult>>
LoadAotCompilationResult(const std::string & serialized_aot_result)283   LoadAotCompilationResult(const std::string& serialized_aot_result) {
284     return Unimplemented("LoadAotCompilationResult unimplemented.");
285   }
286 
287   // Compiles a set of HLO modules that can run in parallel, potentially
288   // communicating data between the modules, and returns a corresponding
289   // sequence of executable objects.
290   //
291   // TODO(b/68666782): Remove this method after adding support for multiple
292   // modules to RunHloPasses and RunBackends.
293   virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
294       std::unique_ptr<HloModuleGroup> module_group,
295       std::vector<std::vector<se::StreamExecutor*>> stream_exec,
296       const CompileOptions& options) = 0;
Compile(std::unique_ptr<HloModuleGroup> module_group,std::vector<std::vector<se::StreamExecutor * >> stream_exec,se::DeviceMemoryAllocator * device_allocator)297   StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
298       std::unique_ptr<HloModuleGroup> module_group,
299       std::vector<std::vector<se::StreamExecutor*>> stream_exec,
300       se::DeviceMemoryAllocator* device_allocator) {
301     return Compile(std::move(module_group), stream_exec,
302                    CompileOptions{device_allocator});
303   }
304 
305   // Returns the backend configurations that the backend will consider for the
306   // given HLO. Returns no configurations if the backend does not support
307   // configurations for the given HLO.
308   //
309   // The stream executor is passed in to provide information about the hardware
310   // that the backend configurations would be targeting.
311   virtual std::vector<std::unique_ptr<tensorflow::protobuf::Message>>
312   ComputeBackendConfigs(const HloInstruction& hlo,
313                         se::StreamExecutor* executor) const;
314 
315   // Returns the backend configuration that the backend chooses by default for
316   // the given HLO. Returns no configuration if the backend does not support
317   // configurations for the given HLO.
318   //
319   // The stream executor is passed in to provide information about the hardware
320   // that the backend configurations would be targeting.
321   virtual std::unique_ptr<tensorflow::protobuf::Message>
322   ComputeDefaultBackendConfig(const HloInstruction& hlo,
323                               se::StreamExecutor* executor) const;
324 
325   // Compiles the HLO module group for ahead-of-time execution.  This is
326   // intended for use in static compilation.
327   virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
328   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
329                      const AotCompilationOptions& options) = 0;
330 
331   // Similar to CompileAheadOfTime above but AotCompilationMetadata
332   // has an argument that can be populated during compilation.
333   virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
334   CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
335                      const AotCompilationOptions& options,
336                      std::unique_ptr<AotCompilationMetadata>* metadata);
337 
338   /////
339   // The Compiler class also serves as a point to register compiler objects
340   // for the various platforms.
341 
342   using CompilerFactory = std::function<std::unique_ptr<Compiler>()>;
343 
344   // Registers the compiler singleton for the platform. This is assumed to
345   // be a singleton, so no ownership is transferred.
346   //
347   // Precondition: a platform kind must not be registered more than once.
348   static void RegisterCompilerFactory(se::Platform::Id platform_id,
349                                       CompilerFactory compiler_factory);
350 
351   // Returns the compiler singleton pointer if it is available for the given
352   // platform, or an error status if it is not.
353   static StatusOr<Compiler*> GetForPlatform(const se::Platform* platform);
354 
355   // Returns a function that computes the size in bytes of the logical
356   // buffer that contains a shape.
357   virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const = 0;
358 
359   // Returns a function that computes the size in bytes of a given
360   // logical buffer.
BufferSizeBytesFunction()361   std::function<int64_t(const BufferValue&)> BufferSizeBytesFunction() {
362     HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction();
363     return [shape_size](const BufferValue& buffer) {
364       return shape_size(buffer.shape());
365     };
366   }
367 
DefaultDeviceShapeRepresentation(const Shape & shape)368   virtual Shape DefaultDeviceShapeRepresentation(const Shape& shape) const {
369     return shape;
370   }
371 
372  private:
373   // Mutex that guards the platform-compiler map.
374   static absl::Mutex platform_compiler_mutex_;
375 
376   // Map from platform kind to compiler factory.
377   static std::map<se::Platform::Id, CompilerFactory>*
378   GetPlatformCompilerFactories();
379 
380   // Map from platform kind to compiler instance, if we made one already (based
381   // on the factories above).
382   static std::map<se::Platform::Id, std::unique_ptr<Compiler>>*
383   GetPlatformCompilers();
384 };
385 
386 }  // namespace xla
387 
388 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
389