1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ 18 19 #include <map> 20 #include <set> 21 #include <vector> 22 23 #include "absl/container/flat_hash_map.h" 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/literal.h" 26 #include "tensorflow/compiler/xla/service/executable.h" 27 #include "tensorflow/compiler/xla/service/shaped_buffer.h" 28 #include "tensorflow/compiler/xla/statusor.h" 29 #include "tensorflow/compiler/xla/types.h" 30 #include "tensorflow/compiler/xla/xla_data.pb.h" 31 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 32 #include "tensorflow/stream_executor/device_memory.h" 33 34 namespace xla { 35 36 // The TransferManager interface lets backends provide platform-specific 37 // mechanisms for constructing literals from given device memory handles. 38 // This lets each platform customize how literals are transferred to/from the 39 // device in terms of padding, leading dimension, etc. 40 class TransferManager { 41 public: ~TransferManager()42 virtual ~TransferManager() {} 43 44 // Returns the ID of the platform that this transfer manager acts on. 45 virtual se::Platform::Id PlatformId() const = 0; 46 47 // Returns the shape of the on-device representation for the given shape on 48 // the host. This is intended for use with ShapedBuffer where buffers are 49 // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user 50 // needing to consider device-specific behaviors. HostShapeToDeviceShape(const Shape & host_shape)51 virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const { 52 // Strips off any preexisting tiling or memory space information. 53 // TODO(phawkins): fix clients not to including tiling or memory space 54 // information in shapes passed to this function and turn this into an 55 // assertion. 56 return ShapeUtil::DeviceShapeToHostShape(host_shape); 57 } 58 59 // Base class for specifying platform specific transfer metadata that can be 60 // used to tell the underlying implementation to perform specific optimization 61 // to a transfer. Actual metadata passed to supported transfer methods should 62 // subclass this class. 63 class TransferMetadata { 64 public: 65 virtual ~TransferMetadata() = 0; 66 }; 67 // Returns a literal containing the data held in the given ShapedBuffer 68 // using the provided executor. This operation is performed synchronously 69 // without waiting for any other operation on a stream to complete. 70 // 71 // This function should be avoided in favor of the asynchronous version below. 72 // 73 // Optionally caller can specify platform-specific transfer metadata that 74 // tells the actual implementation to do something special. 75 virtual StatusOr<Literal> TransferLiteralFromDevice( 76 se::Stream* stream, const ShapedBuffer& device_buffer, 77 const TransferMetadata* transfer_metadata); TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer)78 StatusOr<Literal> TransferLiteralFromDevice( 79 se::Stream* stream, const ShapedBuffer& device_buffer) { 80 return TransferLiteralFromDevice(stream, device_buffer, nullptr); 81 } 82 virtual Status TransferLiteralFromDevice( 83 se::Stream* stream, const ShapedBuffer& device_buffer, 84 const MutableBorrowingLiteral& literal, 85 const TransferMetadata* transfer_metadata); TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer,const MutableBorrowingLiteral & literal)86 Status TransferLiteralFromDevice(se::Stream* stream, 87 const ShapedBuffer& device_buffer, 88 const MutableBorrowingLiteral& literal) { 89 return TransferLiteralFromDevice(stream, device_buffer, literal, nullptr); 90 } 91 92 // Begins transferring a literal containing the data held in the given 93 // ShapedBuffer using the provided executor. 94 // 95 // This operation is performed asynchronously on the given stream. It returns 96 // once the transfer is enqueued. 'done' is invoked with the result when 97 // complete. 98 // 99 // device_buffer is copied by reference and must live at least until done() is 100 // invoked. 101 // 102 // Optionally caller can specify platform-specific transfer metadata that 103 // tells the actual implementation to do something special. 104 virtual void TransferLiteralFromDevice( 105 se::Stream* stream, const ShapedBuffer& device_buffer, 106 MutableBorrowingLiteral literal, std::function<void(Status)> done, 107 const TransferMetadata* transfer_metadata) = 0; TransferLiteralFromDevice(se::Stream * stream,const ShapedBuffer & device_buffer,MutableBorrowingLiteral literal,std::function<void (Status)> done)108 void TransferLiteralFromDevice(se::Stream* stream, 109 const ShapedBuffer& device_buffer, 110 MutableBorrowingLiteral literal, 111 std::function<void(Status)> done) { 112 return TransferLiteralFromDevice(stream, device_buffer, literal, done, 113 nullptr); 114 } 115 116 // Transfers the given literal into the previously allocated device memory 117 // represented by the given ShapedBuffer using the given executor. The shape 118 // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, 119 // but need not have the same layout. 120 // 121 // This operation is performed synchronously without waiting for any other 122 // operation on a stream to complete. This function should be avoided in favor 123 // of the asynchronous version below. 124 // 125 // Optionally caller can specify platform-specific transfer metadata that 126 // tells the actual implementation to do something special. 127 virtual Status TransferLiteralToDevice( 128 se::Stream* stream, const LiteralSlice& literal, 129 const ShapedBuffer& device_buffer, 130 const TransferMetadata* transfer_metadata); TransferLiteralToDevice(se::Stream * stream,const LiteralSlice & literal,const ShapedBuffer & device_buffer)131 Status TransferLiteralToDevice(se::Stream* stream, 132 const LiteralSlice& literal, 133 const ShapedBuffer& device_buffer) { 134 return TransferLiteralToDevice(stream, literal, device_buffer, nullptr); 135 } 136 137 // Transfers the given literal into the previously allocated device memory 138 // represented by the given ShapedBuffer using the given executor. The shape 139 // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, 140 // but need not have the same layout. 141 // 142 // This operation is performed asynchronously on the given stream. It returns 143 // once the transfer is enqueued, and may return before the transfer has 144 // completed. 145 // 146 // The caller may free the data structures 'literal' and 'device_buffer' 147 // immediately after this function returns, however their constituent buffers 148 // on both host and device must remain valid until the enqueued transfer has 149 // completed on 'stream'. 150 // 151 // Optionally caller can specify platform-specific transfer metadata that 152 // tells the actual implementation to do something special. 153 virtual Status TransferLiteralToDeviceAsync( 154 se::Stream* stream, const LiteralSlice& literal, 155 const ShapedBuffer& device_buffer, 156 const TransferMetadata* transfer_metadata) = 0; TransferLiteralToDeviceAsync(se::Stream * stream,const LiteralSlice & literal,const ShapedBuffer & device_buffer)157 Status TransferLiteralToDeviceAsync(se::Stream* stream, 158 const LiteralSlice& literal, 159 const ShapedBuffer& device_buffer) { 160 return TransferLiteralToDeviceAsync(stream, literal, device_buffer, 161 nullptr); 162 } 163 164 // Convenience methods for transferring an array to or from the device at a 165 // known address. This avoids having to construct a ShapedBuffer just to 166 // transfer an array at a known address. 167 // 168 // Optionally caller can specify platform-specific transfer metadata that 169 // tells the actual implementation to do something special. 170 Status TransferArrayToDevice( 171 se::Stream* stream, const LiteralSlice& literal, 172 const se::DeviceMemoryBase& dest, 173 const TransferMetadata* transfer_metadata = nullptr); 174 void TransferArrayFromDevice( 175 se::Stream* stream, const Shape& shape, 176 const se::DeviceMemoryBase& source, 177 const MutableBorrowingLiteral& literal, std::function<void(Status)> done, 178 const TransferMetadata* transfer_metadata = nullptr); 179 180 Status TransferArrayToDeviceAsync( 181 se::Stream* stream, const LiteralSlice& literal, 182 const se::DeviceMemoryBase& dest, 183 const TransferMetadata* transfer_metadata = nullptr); 184 StatusOr<Literal> TransferArrayFromDevice( 185 se::Stream* stream, const Shape& shape, 186 const se::DeviceMemoryBase& source, 187 const TransferMetadata* transfer_metadata = nullptr); 188 189 // Read from a device buffer and update the dynamic dimension sizes of 190 // `host_shape` and `device_shape`. The function takes in bounded dynamic 191 // shapes, and returns static shapes with dynamic shapes updated. 192 // The shape of the buffer also have to be compatible with the host shape and 193 // device shape. 194 virtual Status ReadDynamicShapes(se::Stream* stream, 195 ShapedBuffer* device_buffer, 196 Shape* device_shape); 197 198 // Transfers the given literal into the Infeed interface of the device, 199 // using the given executor. 200 virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor, 201 const LiteralSlice& literal) = 0; 202 203 // Transfers the given literal from the Outfeed interface of the device, 204 // using the given executor. The shape and layout are determined by the 205 // shape and layout of `literal`. 206 virtual Status TransferLiteralFromOutfeed( 207 se::StreamExecutor* executor, MutableBorrowingLiteral literal) = 0; 208 209 // Resets the devices associated with this transfer manager. 210 virtual Status ResetDevices( 211 absl::Span<se::StreamExecutor* const> executor) = 0; 212 213 // Given an allocated ShapedBuffer, constructs the tuple index table(s) in 214 // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the 215 // ShapedBuffer is array-shaped this method does nothing. 216 Status WriteTupleIndexTables(se::Stream* stream, 217 const ShapedBuffer& device_buffer); 218 Status WriteTupleIndexTablesAsync(se::Stream* stream, 219 const ShapedBuffer& device_buffer); 220 221 // Writes a tuple index buffer for the root of 'device_buffer', which must 222 // be a tuple. Unlike WriteTupleIndexTables, only writes the root buffer, 223 // rather than writing all subbuffers. This method is always asynchronous. 224 Status WriteRootTupleIndexTable(se::Stream* stream, 225 const ShapedBuffer& device_buffer); 226 Status WriteRootTupleIndexTable( 227 se::Stream* stream, 228 const ShapeTree<MaybeOwningDeviceMemory>& buffer_tree); 229 230 // Determines the byte size requirement for the given shape on the underlying 231 // architecture. This will be used to allocate an appropriately sized memory 232 // region for a host-to-device transfer. 233 virtual int64_t GetByteSizeRequirement(const Shape& shape) const = 0; 234 235 // Chooses a compact layout for 'shape', ignoring any existing layout on 236 // 'shape'. What "reasonable" means is left up to the backend. The 237 // intended use case is to choose a layout that avoids excessive padding on 238 // devices that have tiled memory architectures. 239 // The default implementation always picks a default (major-to-minor) layout. 240 // Fails if 'shape' cannot be represented by the device. 241 virtual StatusOr<Shape> ChooseCompactLayoutForShape( 242 const Shape& host_shape) const; 243 244 // For the given shape, chooses a layout for infeed. The returned shape 245 // has the same dimensions as the original shape, and only the layout is 246 // changed. 247 virtual Shape ChooseGoodInfeedLayout(const Shape& shape) const; 248 249 typedef std::function<Shape(const Shape&)> DeviceShapeRepresentationFn; 250 251 // Allocates a ScopedShapedBuffer which can hold data with the given on-host 252 // shape. The on-device shape may be different as indicated by 253 // HostShapeToDeviceShape. 254 StatusOr<ScopedShapedBuffer> AllocateScopedShapedBuffer( 255 const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator, 256 int device_ordinal, 257 DeviceShapeRepresentationFn shape_representation_fn = nullptr); 258 259 // The given ShapedBuffer holds a handle to allocated memory, but it is not 260 // in the general case legal to immediately copy or access that allocated 261 // memory because queued operations on the device may alias that memory. 262 // Memory ordering is enforced by the Stream's happens-before relationship 263 // which allows eager deallocation and reallocation of buffers host-side even 264 // if the device hasn't finished with them. 265 // 266 // In certain cases, it can be known that a ShapedBuffer does not have any 267 // conflicting accesses on the device and thus is eligible to be accessed at 268 // any time from the host. 269 // 270 // This function returns true if device_buffer can be accessed immediately 271 // without waiting for the Stream's previously enqueued items. This only 272 // returns true if all subbuffers in device_buffer can be accessed 273 // immediately. CanShapedBufferBeAccessedNow(se::StreamExecutor * executor,const ShapedBuffer & device_buffer)274 virtual bool CanShapedBufferBeAccessedNow( 275 se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const { 276 return false; 277 } 278 279 // Equivalent to CanShapedBufferBeAccessedNow but for a single device buffer. CanBufferBeAccessedNow(se::StreamExecutor * executor,const se::DeviceMemoryBase & device_buffer)280 virtual bool CanBufferBeAccessedNow( 281 se::StreamExecutor* executor, 282 const se::DeviceMemoryBase& device_buffer) const { 283 return false; 284 } 285 286 ///// 287 // The TransferManager class also serves as a point to register objects for 288 // the various platforms. 289 290 // Registers the TransferManager singleton for the platform kind. This is 291 // assumed to be a singleton, so no ownership is transferred. 292 // 293 // Precondition: a platform kind must not be registered more than once. 294 typedef std::unique_ptr<TransferManager> (*TransferManagerCreationFunction)(); 295 static void RegisterTransferManager( 296 se::Platform::Id platform_id, 297 TransferManagerCreationFunction transfer_manager); 298 299 // Returns the transfer manager singleton pointer if it is available for the 300 // given platform, or an error status if it is not. 301 static StatusOr<TransferManager*> GetForPlatform( 302 const se::Platform* platform); 303 304 // Writes the given device-memory pointers in 'elements' to the given region 305 // to construct a tuple index table in the platform-specific tuple 306 // representation. 307 virtual Status WriteSingleTupleIndexTable( 308 se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements, 309 const Shape& shape, se::DeviceMemoryBase* region) = 0; 310 311 protected: 312 // Transfer a memory block of the given size from the device source into the 313 // 'destination' buffer. 314 // 315 // size is the size to transfer to destination in bytes. 316 virtual Status TransferBufferFromDevice(se::Stream* stream, 317 const se::DeviceMemoryBase& source, 318 int64_t size, void* destination); 319 320 // Transfer a memory block of the given size from 'source' buffer to the given 321 // destination of the device. 322 // 323 // size is the size to transfer from source in bytes. 324 virtual Status TransferBufferToDevice(se::Stream* stream, int64_t size, 325 const void* source, 326 se::DeviceMemoryBase* destination); 327 328 private: 329 // The mutex that guards the platform-to-transfer manager map. 330 static absl::Mutex platform_transfer_manager_mutex_; 331 332 // State kept for each kind of TransferManager. Registration functions 333 // set up creation_function, and then we use that to lazily create 334 // "manager" the first time GetForPlatform is invoked for a particular id. 335 struct State { 336 std::unique_ptr<TransferManager> manager; 337 TransferManagerCreationFunction creation_function = nullptr; 338 }; 339 340 // Map from platform kind to transfer manager singleton. 341 static absl::flat_hash_map<se::Platform::Id, State>* 342 GetPlatformTransferManagers(); 343 }; 344 345 } // namespace xla 346 347 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_ 348