/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_
#define TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_

#include <memory>
#include <utility>

#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"

namespace tensorflow {

// Holds some information about the platform on which an
// XlaLaunch/_XlaCompile/_XlaRun op must run on. Provides a common layer of
// abstraction for normal and XLA devices.
28 class XlaPlatformInfo { 29 public: XlaPlatformInfo()30 XlaPlatformInfo() : device_type_("") {} 31 XlaPlatformInfo(XlaPlatformInfo&&) = default; XlaPlatformInfo(const DeviceType device_type,se::Platform::Id platform_id,const XlaDevice::Metadata * xla_device_metadata,std::shared_ptr<se::DeviceMemoryAllocator> device_allocator)32 explicit XlaPlatformInfo( 33 const DeviceType device_type, se::Platform::Id platform_id, 34 const XlaDevice::Metadata* xla_device_metadata, 35 std::shared_ptr<se::DeviceMemoryAllocator> device_allocator) 36 : device_type_(device_type), 37 platform_id_(platform_id), 38 xla_device_metadata_(xla_device_metadata), 39 device_allocator_(device_allocator) {} 40 41 XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default; 42 UseMultipleStreams()43 bool UseMultipleStreams() const { 44 return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams(); 45 } 46 47 // Non-null only when run on an XLA device. custom_allocator()48 std::shared_ptr<se::DeviceMemoryAllocator> custom_allocator() const { 49 return device_allocator_; 50 } 51 device_type()52 DeviceType device_type() const { return device_type_; } 53 54 // This is equal to xla_device_metadata()->platform()->id() if 55 // xla_device_metadata() is not nullptr. platform_id()56 se::Platform::Id platform_id() const { return platform_id_; } 57 58 // This may be null if the op this XlaPlatformInfo is for was not placed on an 59 // XLA device. xla_device_metadata()60 const XlaDevice::Metadata* xla_device_metadata() const { 61 return xla_device_metadata_; 62 } is_on_xla_device()63 bool is_on_xla_device() const { return xla_device_metadata() != nullptr; } 64 65 private: 66 DeviceType device_type_; 67 se::Platform::Id platform_id_; 68 69 // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the 70 // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the 71 // XlaLaunch/_XlaCompile/_XlaRun OpKernel. 
72 const XlaDevice::Metadata* xla_device_metadata_; 73 74 // If the op associated with this XlaPlatformInfo is placed on an XLA device 75 // then device_allocator_ is the xla::Backend's memory allocator. If the op 76 // is placed on a regular CPU or GPU device then device_allocator_ is null. 77 // The allocator is of unknown provenance; keep it in a shared pointer to 78 // set an artificial refcount of one. 79 std::shared_ptr<se::DeviceMemoryAllocator> device_allocator_; 80 81 TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo); 82 }; 83 84 // Returns a set containing the device ids contained in visible_device_list or 85 // nullopt if it is empty. It returns error in case of malformed configuration 86 // string. 87 StatusOr<std::optional<std::set<int>>> ParseVisibleDeviceList( 88 absl::string_view visible_device_list); 89 90 // Returns created XLA compilation cache. 91 Status BuildXlaCompilationCache(DeviceBase* dev, FunctionLibraryRuntime* flr, 92 const XlaPlatformInfo& platform_info, 93 XlaCompilationCache** cache); 94 95 // Returns information about the platform from kernel context. 96 XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device); 97 98 // Returns allocator from platform info if non-null, or populate and return a 99 // pointer to the allocator adapter with allocator from context. 100 // 101 // This is necessary because for XLA devices the underlying TF allocator returns 102 // dummy tensors. 103 // 104 // `stream` parameter is nullable when running on host. 105 std::shared_ptr<se::DeviceMemoryAllocator> GetAllocator( 106 DeviceBase* device, se::Stream* stream, 107 const XlaPlatformInfo& platform_info); 108 109 // Returns created options for the XLA compiler, and writes the used allocator 110 // into `tf_allocator_adapter`. 
// `has_ref_vars`: presumably whether the cluster may reference TF ref
// variables, which affects compilation — TODO confirm against callers.
XlaCompiler::Options GenerateCompilerOptions(
    const XlaCompilationCache& cache,
    const FunctionLibraryRuntime& function_library, DeviceBase* device,
    se::Stream* stream, const XlaPlatformInfo& platform_info,
    bool has_ref_vars);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_