1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // The compiler API is used by the XLA service to generate executables that 17 // run on a given platform. This is a registry and abstract interface, for 18 // pluggability by the various platforms. 19 20 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_ 21 #define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_ 22 23 #include <functional> 24 #include <map> 25 #include <memory> 26 #include <string> 27 #include <vector> 28 29 #include "absl/strings/string_view.h" 30 #include "absl/types/span.h" 31 #include "tensorflow/compiler/xla/service/buffer_assignment.h" 32 #include "tensorflow/compiler/xla/service/buffer_value.h" 33 #include "tensorflow/compiler/xla/service/computation_placer.h" 34 #include "tensorflow/compiler/xla/service/executable.h" 35 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 36 #include "tensorflow/compiler/xla/service/hlo_module.h" 37 #include "tensorflow/compiler/xla/service/hlo_module_config.h" 38 #include "tensorflow/compiler/xla/service/hlo_module_group.h" 39 #include "tensorflow/compiler/xla/service/logical_buffer.h" 40 #include "tensorflow/compiler/xla/statusor.h" 41 #include "tensorflow/compiler/xla/types.h" 42 #include "tensorflow/core/platform/protobuf.h" 43 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 44 #include "tensorflow/core/platform/threadpool.h" 45 46 namespace xla { 47 48 // The following types are used for ahead of time compilation. 49 50 // Contains the object file data created as a result of ahead-of-time 51 // computation. 52 using ObjectFileData = std::vector<char>; 53 54 class Compiler; 55 56 // Abstract superclass describing the result of an ahead-of-time compilation. 57 class AotCompilationResult { 58 public: 59 AotCompilationResult(const AotCompilationResult&) = delete; 60 AotCompilationResult& operator=(AotCompilationResult const&) = delete; 61 62 virtual ~AotCompilationResult() = default; 63 SerializeAsString()64 virtual StatusOr<std::string> SerializeAsString() const { 65 return Unimplemented("SerializeAsString unimplemented."); 66 } 67 LoadExecutable(Compiler * compiler,se::StreamExecutor * executor)68 virtual StatusOr<std::unique_ptr<Executable>> LoadExecutable( 69 Compiler* compiler, se::StreamExecutor* executor) const { 70 return Unimplemented("LoadExecutable unimplemented."); 71 } 72 73 protected: 74 AotCompilationResult() = default; 75 }; 76 77 // Abstract superclass describing options to an ahead-of-time compilation. 78 class AotCompilationOptions { 79 public: 80 AotCompilationOptions(const AotCompilationOptions&) = delete; 81 AotCompilationOptions& operator=(AotCompilationOptions const&) = delete; 82 AotCompilationOptions(se::Platform::Id platform_id)83 explicit AotCompilationOptions(se::Platform::Id platform_id) 84 : platform_id_(platform_id), debug_options_(GetDebugOptionsFromFlags()) {} 85 virtual ~AotCompilationOptions() = default; 86 87 // Returns the ID of the platform to which these options apply. PlatformId()88 virtual se::Platform::Id PlatformId() const { return platform_id_; } 89 replica_count()90 virtual int64_t replica_count() const { return 0; } num_cores()91 virtual int64_t num_cores() const { return 0; } use_spmd_partitioning()92 virtual bool use_spmd_partitioning() const { return false; } use_auto_spmd_partitioning()93 virtual bool use_auto_spmd_partitioning() const { return false; } auto_spmd_partitioning_mesh_shape()94 virtual std::vector<int64_t> auto_spmd_partitioning_mesh_shape() const { 95 return {}; 96 } auto_spmd_partitioning_mesh_ids()97 virtual std::vector<int64_t> auto_spmd_partitioning_mesh_ids() const { 98 return {}; 99 } deduplicate_hlo()100 virtual bool deduplicate_hlo() const { return false; } matrix_unit_operand_precision()101 virtual PrecisionConfig::Precision matrix_unit_operand_precision() const { 102 return PrecisionConfig::DEFAULT; 103 } 104 105 // Optional allocator that may be used for allocating temp space on the device 106 // during compilation. device_allocator()107 se::DeviceMemoryAllocator* device_allocator() const { 108 return device_allocator_; 109 } set_device_allocator(se::DeviceMemoryAllocator * device_allocator)110 void set_device_allocator(se::DeviceMemoryAllocator* device_allocator) { 111 device_allocator_ = device_allocator; 112 } 113 debug_options()114 const DebugOptions& debug_options() const { return debug_options_; } mutable_debug_options()115 DebugOptions* mutable_debug_options() { return &debug_options_; } 116 has_static_device_assignment()117 bool has_static_device_assignment() const { 118 return static_device_assignment_.has_value(); 119 } static_device_assignment()120 const DeviceAssignment& static_device_assignment() const { 121 CHECK(static_device_assignment_.has_value()); 122 return *static_device_assignment_; 123 } set_static_device_assignment(const DeviceAssignment & device_assignment)124 void set_static_device_assignment(const DeviceAssignment& device_assignment) { 125 static_device_assignment_ = device_assignment; 126 } 127 fusion_config_collection()128 FusionConfigCollection fusion_config_collection() const { 129 return fusion_config_collection_; 130 } set_fusion_config_collection(FusionConfigCollection fusion_config_collection)131 void set_fusion_config_collection( 132 FusionConfigCollection fusion_config_collection) { 133 fusion_config_collection_ = fusion_config_collection; 134 } 135 fusion_config()136 const std::vector<std::vector<bool>>& fusion_config() const { 137 return fusion_config_; 138 } set_fusion_config(const std::vector<std::vector<bool>> & fusion_config)139 void set_fusion_config(const std::vector<std::vector<bool>>& fusion_config) { 140 fusion_config_ = fusion_config; 141 } 142 executor()143 se::StreamExecutor* executor() const { return executor_; } set_executor(se::StreamExecutor * executor)144 void set_executor(se::StreamExecutor* executor) { executor_ = executor; } 145 146 // Optional profile_version and cache key may be used to trigger recompilation 147 // when a compilation cache is used. profile_version()148 int64_t profile_version() const { return profile_version_; } set_profile_version(int64_t profile_version)149 void set_profile_version(int64_t profile_version) { 150 profile_version_ = profile_version; 151 } 152 cache_key()153 absl::string_view cache_key() const { return cache_key_; } set_cache_key(absl::string_view cache_key)154 void set_cache_key(absl::string_view cache_key) { 155 cache_key_ = std::string(cache_key); 156 } 157 run_backend_only()158 bool run_backend_only() const { return run_backend_only_; } set_run_backend_only(bool run_backend_only)159 void set_run_backend_only(bool run_backend_only) { 160 run_backend_only_ = run_backend_only; 161 } 162 sanitize_dataflow()163 bool sanitize_dataflow() const { return sanitize_dataflow_; } set_sanitize_dataflow(bool sanitize_dataflow)164 void set_sanitize_dataflow(bool sanitize_dataflow) { 165 sanitize_dataflow_ = sanitize_dataflow; 166 } 167 sanitize_abilists_dataflow()168 const std::vector<std::string>& sanitize_abilists_dataflow() const { 169 return sanitize_abilists_dataflow_; 170 } set_sanitize_abilists_dataflow(const std::vector<std::string> & abilists)171 void set_sanitize_abilists_dataflow( 172 const std::vector<std::string>& abilists) { 173 sanitize_abilists_dataflow_ = abilists; 174 } 175 176 protected: 177 AotCompilationOptions(); 178 179 private: 180 se::Platform::Id platform_id_; 181 se::DeviceMemoryAllocator* device_allocator_ = nullptr; 182 DebugOptions debug_options_; 183 std::optional<DeviceAssignment> static_device_assignment_; 184 std::vector<std::vector<bool>> fusion_config_; 185 FusionConfigCollection fusion_config_collection_ = 186 FusionConfigCollection::kOff; 187 se::StreamExecutor* executor_ = nullptr; 188 int64_t profile_version_ = 0; 189 std::string cache_key_; 190 bool run_backend_only_ = false; 191 bool sanitize_dataflow_ = false; 192 std::vector<std::string> sanitize_abilists_dataflow_; 193 }; 194 195 // Abstract superclass describing metadata produced during ahead-of-time 196 // compilation. 197 class AotCompilationMetadata { 198 public: 199 AotCompilationMetadata(const AotCompilationMetadata&) = delete; 200 AotCompilationMetadata& operator=(AotCompilationMetadata const&) = delete; ToString()201 virtual std::string ToString() const { return ""; } 202 virtual ~AotCompilationMetadata() = default; 203 204 protected: 205 AotCompilationMetadata() = default; 206 }; 207 208 // Abstract compiler interface that is subclassed for compilation on a 209 // particular platform. 210 // 211 // The compiler ties together high level optimization (HLO) and low level 212 // optimization (LLO) / codegen (CG) to generate efficient executables for the 213 // target platform. 214 // 215 // The platform-based compiler singletons are registered via module initializers 216 // in their corresponding XLA compiler libraries, and are registered via the 217 // RegisterCompilerFactory API below. 218 // 219 // Thread-safety: subclasses of Compiler must be thread-safe, as multiple 220 // XLA clients may be requesting compilation concurrently for a given 221 // platform. 222 class Compiler { 223 public: 224 struct CompileOptions { 225 // If device_allocator is not null, the compiler may use it to allocate temp 226 // space on the device for use during compilation. For example, the 227 // compiler may allocate buffers on the device and then run variants of a 228 // given algorithm over those buffers, to see which variant is fastest. Any 229 // space allocated will be deallocated before the compilation returns. 230 se::DeviceMemoryAllocator* device_allocator = nullptr; 231 232 // An optional thread pool for parallel compilation. 233 tensorflow::thread::ThreadPool* thread_pool = nullptr; 234 }; 235 ~Compiler()236 virtual ~Compiler() {} 237 238 // Returns the ID of the platform that this compiler targets. 239 virtual se::Platform::Id PlatformId() const = 0; 240 241 // Runs Hlo passes to optimize the given Hlo module, returns the optimized 242 // module. 243 virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses( 244 std::unique_ptr<HloModule> module, se::StreamExecutor* executor, 245 const CompileOptions& options) = 0; RunHloPasses(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,se::DeviceMemoryAllocator * device_allocator)246 StatusOr<std::unique_ptr<HloModule>> RunHloPasses( 247 std::unique_ptr<HloModule> module, se::StreamExecutor* executor, 248 se::DeviceMemoryAllocator* device_allocator) { 249 return RunHloPasses(std::move(module), executor, 250 CompileOptions{device_allocator}); 251 } 252 253 // Performs scheduling and buffer assignment and returns the buffer 254 // assignments. 255 // The returned 'BufferAssignment' retains a pointer to the 'HloModule', so 256 // the module must live at least as long as the buffer assignments. AssignBuffers(const HloModule * module)257 virtual StatusOr<std::unique_ptr<BufferAssignment>> AssignBuffers( 258 const HloModule* module) { 259 return Unimplemented("This compiler does not support this method"); 260 } 261 262 // Compiles the HLO module for execution on a device given by the executor, 263 // and returns an executable object or an error status. No HLO passes are 264 // applied to module. Generally a module should be passed through RunHloPasses 265 // prior to calling this method because some HLO passes are required for 266 // correctness. Takes ownership of the HLO module. 267 // 268 // The compiler may optionally specialize to the individual device 269 // (not just type of device) indicated by the executor. 270 virtual StatusOr<std::unique_ptr<Executable>> RunBackend( 271 std::unique_ptr<HloModule> module, se::StreamExecutor* executor, 272 const CompileOptions& options) = 0; RunBackend(std::unique_ptr<HloModule> module,se::StreamExecutor * executor,se::DeviceMemoryAllocator * device_allocator)273 StatusOr<std::unique_ptr<Executable>> RunBackend( 274 std::unique_ptr<HloModule> module, se::StreamExecutor* executor, 275 se::DeviceMemoryAllocator* device_allocator) { 276 return RunBackend(std::move(module), executor, 277 CompileOptions{device_allocator}); 278 } 279 280 // Returns a (deserialized) AotCompilationResult from a serialized 281 // AotCompilationResult. 282 virtual StatusOr<std::unique_ptr<AotCompilationResult>> LoadAotCompilationResult(const std::string & serialized_aot_result)283 LoadAotCompilationResult(const std::string& serialized_aot_result) { 284 return Unimplemented("LoadAotCompilationResult unimplemented."); 285 } 286 287 // Compiles a set of HLO modules that can run in parallel, potentially 288 // communicating data between the modules, and returns a corresponding 289 // sequence of executable objects. 290 // 291 // TODO(b/68666782): Remove this method after adding support for multiple 292 // modules to RunHloPasses and RunBackends. 293 virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile( 294 std::unique_ptr<HloModuleGroup> module_group, 295 std::vector<std::vector<se::StreamExecutor*>> stream_exec, 296 const CompileOptions& options) = 0; Compile(std::unique_ptr<HloModuleGroup> module_group,std::vector<std::vector<se::StreamExecutor * >> stream_exec,se::DeviceMemoryAllocator * device_allocator)297 StatusOr<std::vector<std::unique_ptr<Executable>>> Compile( 298 std::unique_ptr<HloModuleGroup> module_group, 299 std::vector<std::vector<se::StreamExecutor*>> stream_exec, 300 se::DeviceMemoryAllocator* device_allocator) { 301 return Compile(std::move(module_group), stream_exec, 302 CompileOptions{device_allocator}); 303 } 304 305 // Returns the backend configurations that the backend will consider for the 306 // given HLO. Returns no configurations if the backend does not support 307 // configurations for the given HLO. 308 // 309 // The stream executor is passed in to provide information about the hardware 310 // that the backend configurations would be targeting. 311 virtual std::vector<std::unique_ptr<tensorflow::protobuf::Message>> 312 ComputeBackendConfigs(const HloInstruction& hlo, 313 se::StreamExecutor* executor) const; 314 315 // Returns the backend configuration that the backend chooses by default for 316 // the given HLO. Returns no configuration if the backend does not support 317 // configurations for the given HLO. 318 // 319 // The stream executor is passed in to provide information about the hardware 320 // that the backend configurations would be targeting. 321 virtual std::unique_ptr<tensorflow::protobuf::Message> 322 ComputeDefaultBackendConfig(const HloInstruction& hlo, 323 se::StreamExecutor* executor) const; 324 325 // Compiles the HLO module group for ahead-of-time execution. This is 326 // intended for use in static compilation. 327 virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> 328 CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group, 329 const AotCompilationOptions& options) = 0; 330 331 // Similar to CompileAheadOfTime above but AotCompilationMetadata 332 // has an argument that can be populated during compilation. 333 virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>> 334 CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group, 335 const AotCompilationOptions& options, 336 std::unique_ptr<AotCompilationMetadata>* metadata); 337 338 ///// 339 // The Compiler class also serves as a point to register compiler objects 340 // for the various platforms. 341 342 using CompilerFactory = std::function<std::unique_ptr<Compiler>()>; 343 344 // Registers the compiler singleton for the platform. This is assumed to 345 // be a singleton, so no ownership is transferred. 346 // 347 // Precondition: a platform kind must not be registered more than once. 348 static void RegisterCompilerFactory(se::Platform::Id platform_id, 349 CompilerFactory compiler_factory); 350 351 // Returns the compiler singleton pointer if it is available for the given 352 // platform, or an error status if it is not. 353 static StatusOr<Compiler*> GetForPlatform(const se::Platform* platform); 354 355 // Returns a function that computes the size in bytes of the logical 356 // buffer that contains a shape. 357 virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const = 0; 358 359 // Returns a function that computes the size in bytes of a given 360 // logical buffer. BufferSizeBytesFunction()361 std::function<int64_t(const BufferValue&)> BufferSizeBytesFunction() { 362 HloCostAnalysis::ShapeSizeFunction shape_size = ShapeSizeBytesFunction(); 363 return [shape_size](const BufferValue& buffer) { 364 return shape_size(buffer.shape()); 365 }; 366 } 367 DefaultDeviceShapeRepresentation(const Shape & shape)368 virtual Shape DefaultDeviceShapeRepresentation(const Shape& shape) const { 369 return shape; 370 } 371 372 private: 373 // Mutex that guards the platform-compiler map. 374 static absl::Mutex platform_compiler_mutex_; 375 376 // Map from platform kind to compiler factory. 377 static std::map<se::Platform::Id, CompilerFactory>* 378 GetPlatformCompilerFactories(); 379 380 // Map from platform kind to compiler instance, if we made one already (based 381 // on the factories above). 382 static std::map<se::Platform::Id, std::unique_ptr<Compiler>>* 383 GetPlatformCompilers(); 384 }; 385 386 } // namespace xla 387 388 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_ 389