/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_

#include <map>
#include <set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {

// The TransferManager interface lets backends provide platform-specific
// mechanisms for constructing literals from given device memory handles.
// This lets each platform customize how literals are transferred to/from the
// device in terms of padding, leading dimension, etc.
class TransferManager {
 public:
  virtual ~TransferManager() {}

  // Returns the ID of the platform that this transfer manager acts on.
  virtual se::Platform::Id PlatformId() const = 0;

  // Returns the shape of the on-device representation for the given shape on
  // the host. This is intended for use with ShapedBuffer where buffers are
  // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user
  // needing to consider device-specific behaviors.
  virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const {
    // Strips off any preexisting tiling or memory space information.
    // TODO(phawkins): fix clients not to include tiling or memory space
    // information in shapes passed to this function, and turn this into an
    // assertion.
    return ShapeUtil::DeviceShapeToHostShape(host_shape);
  }

  // Base class for specifying platform-specific transfer metadata that can be
  // used to tell the underlying implementation to apply specific
  // optimizations to a transfer. Actual metadata passed to supported transfer
  // methods should subclass this class.
  class TransferMetadata {
   public:
    virtual ~TransferMetadata() = 0;
  };
  // Returns a literal containing the data held in the given ShapedBuffer
  // using the provided executor. This operation is performed synchronously
  // without waiting for any other operation on a stream to complete.
  //
  // This function should be avoided in favor of the asynchronous version
  // below.
  //
  // Optionally, the caller can specify platform-specific transfer metadata
  // that tells the implementation to apply specific optimizations.
  virtual StatusOr<Literal> TransferLiteralFromDevice(
      se::Stream* stream, const ShapedBuffer& device_buffer,
      const TransferMetadata* transfer_metadata);
  StatusOr<Literal> TransferLiteralFromDevice(
      se::Stream* stream, const ShapedBuffer& device_buffer) {
    return TransferLiteralFromDevice(stream, device_buffer, nullptr);
  }
  virtual Status TransferLiteralFromDevice(
      se::Stream* stream, const ShapedBuffer& device_buffer,
      const MutableBorrowingLiteral& literal,
      const TransferMetadata* transfer_metadata);
  Status TransferLiteralFromDevice(se::Stream* stream,
                                   const ShapedBuffer& device_buffer,
                                   const MutableBorrowingLiteral& literal) {
    return TransferLiteralFromDevice(stream, device_buffer, literal, nullptr);
  }
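
  // A minimal usage sketch of the synchronous form (hypothetical; assumes a
  // valid `stream` and an already-populated `device_buffer`):
  //
  //   TF_ASSIGN_OR_RETURN(
  //       Literal result,
  //       transfer_manager->TransferLiteralFromDevice(stream, device_buffer));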

  // Begins transferring a literal containing the data held in the given
  // ShapedBuffer using the provided executor.
  //
  // This operation is performed asynchronously on the given stream. It returns
  // once the transfer is enqueued. 'done' is invoked with the result when
  // complete.
  //
  // device_buffer is copied by reference and must live at least until done()
  // is invoked.
  //
  // Optionally, the caller can specify platform-specific transfer metadata
  // that tells the implementation to apply specific optimizations.
  virtual void TransferLiteralFromDevice(
      se::Stream* stream, const ShapedBuffer& device_buffer,
      MutableBorrowingLiteral literal, std::function<void(Status)> done,
      const TransferMetadata* transfer_metadata) = 0;
  void TransferLiteralFromDevice(se::Stream* stream,
                                 const ShapedBuffer& device_buffer,
                                 MutableBorrowingLiteral literal,
                                 std::function<void(Status)> done) {
    return TransferLiteralFromDevice(stream, device_buffer, literal, done,
                                     nullptr);
  }
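
  // A hypothetical sketch of the asynchronous form; `result` must outlive the
  // callback (using the buffer's host shape is an assumption of this sketch):
  //
  //   Literal result(device_buffer.on_host_shape());
  //   transfer_manager->TransferLiteralFromDevice(
  //       stream, device_buffer, MutableBorrowingLiteral(&result),
  //       [](Status s) { LOG_IF(ERROR, !s.ok()) << s; });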

  // Transfers the given literal into the previously allocated device memory
  // represented by the given ShapedBuffer using the given executor. The shape
  // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
  // but need not have the same layout.
  //
  // This operation is performed synchronously without waiting for any other
  // operation on a stream to complete. This function should be avoided in
  // favor of the asynchronous version below.
  //
  // Optionally, the caller can specify platform-specific transfer metadata
  // that tells the implementation to apply specific optimizations.
  virtual Status TransferLiteralToDevice(
      se::Stream* stream, const LiteralSlice& literal,
      const ShapedBuffer& device_buffer,
      const TransferMetadata* transfer_metadata);
  Status TransferLiteralToDevice(se::Stream* stream,
                                 const LiteralSlice& literal,
                                 const ShapedBuffer& device_buffer) {
    return TransferLiteralToDevice(stream, literal, device_buffer, nullptr);
  }
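
  // A hypothetical sketch, assuming `device_buffer` was allocated with a
  // shape compatible with the literal:
  //
  //   Literal literal = LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}});
  //   TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
  //       stream, literal, device_buffer));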

  // Transfers the given literal into the previously allocated device memory
  // represented by the given ShapedBuffer using the given executor. The shape
  // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
  // but need not have the same layout.
  //
  // This operation is performed asynchronously on the given stream. It returns
  // once the transfer is enqueued, and may return before the transfer has
  // completed.
  //
  // The caller may free the data structures 'literal' and 'device_buffer'
  // immediately after this function returns; however, their constituent
  // buffers on both host and device must remain valid until the enqueued
  // transfer has completed on 'stream'.
  //
  // Optionally, the caller can specify platform-specific transfer metadata
  // that tells the implementation to apply specific optimizations.
  virtual Status TransferLiteralToDeviceAsync(
      se::Stream* stream, const LiteralSlice& literal,
      const ShapedBuffer& device_buffer,
      const TransferMetadata* transfer_metadata) = 0;
  Status TransferLiteralToDeviceAsync(se::Stream* stream,
                                      const LiteralSlice& literal,
                                      const ShapedBuffer& device_buffer) {
    return TransferLiteralToDeviceAsync(stream, literal, device_buffer,
                                        nullptr);
  }
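
  // A hypothetical sketch of the asynchronous form; the literal's buffers
  // must stay alive until the transfer completes on `stream`:
  //
  //   TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDeviceAsync(
  //       stream, literal, device_buffer));
  //   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());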

  // Convenience methods for transferring an array to or from the device at a
  // known address. This avoids having to construct a ShapedBuffer just to
  // transfer an array.
  //
  // Optionally, the caller can specify platform-specific transfer metadata
  // that tells the implementation to apply specific optimizations.
  Status TransferArrayToDevice(
      se::Stream* stream, const LiteralSlice& literal,
      const se::DeviceMemoryBase& dest,
      const TransferMetadata* transfer_metadata = nullptr);
  void TransferArrayFromDevice(
      se::Stream* stream, const Shape& shape,
      const se::DeviceMemoryBase& source,
      const MutableBorrowingLiteral& literal, std::function<void(Status)> done,
      const TransferMetadata* transfer_metadata = nullptr);

  Status TransferArrayToDeviceAsync(
      se::Stream* stream, const LiteralSlice& literal,
      const se::DeviceMemoryBase& dest,
      const TransferMetadata* transfer_metadata = nullptr);
  StatusOr<Literal> TransferArrayFromDevice(
      se::Stream* stream, const Shape& shape,
      const se::DeviceMemoryBase& source,
      const TransferMetadata* transfer_metadata = nullptr);
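
  // A hypothetical sketch, assuming `dest` points at device memory of at
  // least GetByteSizeRequirement(array.shape()) bytes:
  //
  //   Literal array = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
  //   TF_RETURN_IF_ERROR(
  //       transfer_manager->TransferArrayToDevice(stream, array, dest));
  //   TF_ASSIGN_OR_RETURN(Literal round_trip,
  //                       transfer_manager->TransferArrayFromDevice(
  //                           stream, array.shape(), dest));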

  // Reads from a device buffer and updates the dynamic dimension sizes of
  // `host_shape` and `device_shape`. The function takes in bounded dynamic
  // shapes, and returns static shapes with the dynamic dimensions resolved.
  // The shape of the buffer also has to be compatible with the host shape and
  // device shape.
  virtual Status ReadDynamicShapes(se::Stream* stream,
                                   ShapedBuffer* device_buffer,
                                   Shape* device_shape);

  // Transfers the given literal into the Infeed interface of the device,
  // using the given executor.
  virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
                                         const LiteralSlice& literal) = 0;

  // Transfers the given literal from the Outfeed interface of the device,
  // using the given executor. The shape and layout are determined by the
  // shape and layout of `literal`.
  virtual Status TransferLiteralFromOutfeed(
      se::StreamExecutor* executor, MutableBorrowingLiteral literal) = 0;

  // Resets the devices associated with this transfer manager.
  virtual Status ResetDevices(
      absl::Span<se::StreamExecutor* const> executor) = 0;

  // Given an allocated ShapedBuffer, constructs the tuple index table(s) in
  // each buffer of the given ShapedBuffer corresponding to tuple shapes. If
  // the ShapedBuffer is array-shaped this method does nothing.
  Status WriteTupleIndexTables(se::Stream* stream,
                               const ShapedBuffer& device_buffer);
  Status WriteTupleIndexTablesAsync(se::Stream* stream,
                                    const ShapedBuffer& device_buffer);

  // Writes a tuple index buffer for the root of 'device_buffer', which must
  // be a tuple. Unlike WriteTupleIndexTables, only writes the root buffer,
  // rather than writing all subbuffers. This method is always asynchronous.
  Status WriteRootTupleIndexTable(se::Stream* stream,
                                  const ShapedBuffer& device_buffer);
  Status WriteRootTupleIndexTable(
      se::Stream* stream,
      const ShapeTree<MaybeOwningDeviceMemory>& buffer_tree);
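
  // A hypothetical sketch: after allocating a tuple-shaped `tuple_buffer`,
  // write its index tables before handing it to an executable:
  //
  //   TF_RETURN_IF_ERROR(
  //       transfer_manager->WriteTupleIndexTables(stream, tuple_buffer));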

  // Determines the byte size requirement for the given shape on the underlying
  // architecture. This will be used to allocate an appropriately sized memory
  // region for a host-to-device transfer.
  virtual int64_t GetByteSizeRequirement(const Shape& shape) const = 0;

  // Chooses a compact layout for 'shape', ignoring any existing layout on
  // 'shape'. What "compact" means is left up to the backend. The
  // intended use case is to choose a layout that avoids excessive padding on
  // devices that have tiled memory architectures.
  // The default implementation always picks a default (major-to-minor) layout.
  // Fails if 'shape' cannot be represented by the device.
  virtual StatusOr<Shape> ChooseCompactLayoutForShape(
      const Shape& host_shape) const;

  // For the given shape, chooses a layout for infeed. The returned shape
  // has the same dimensions as the original shape, and only the layout is
  // changed.
  virtual Shape ChooseGoodInfeedLayout(const Shape& shape) const;

  typedef std::function<Shape(const Shape&)> DeviceShapeRepresentationFn;

  // Allocates a ScopedShapedBuffer which can hold data with the given on-host
  // shape. The on-device shape may be different as indicated by
  // HostShapeToDeviceShape.
  StatusOr<ScopedShapedBuffer> AllocateScopedShapedBuffer(
      const Shape& on_host_shape, se::DeviceMemoryAllocator* allocator,
      int device_ordinal,
      DeviceShapeRepresentationFn shape_representation_fn = nullptr);
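
  // A hypothetical sketch, assuming `allocator` is a se::DeviceMemoryAllocator
  // for the target device:
  //
  //   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  //   TF_ASSIGN_OR_RETURN(ScopedShapedBuffer buffer,
  //                       transfer_manager->AllocateScopedShapedBuffer(
  //                           shape, allocator, /*device_ordinal=*/0));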

  // The given ShapedBuffer holds a handle to allocated memory, but it is not,
  // in the general case, legal to immediately copy or access that allocated
  // memory because queued operations on the device may alias that memory.
  // Memory ordering is enforced by the Stream's happens-before relationship,
  // which allows eager deallocation and reallocation of buffers host-side even
  // if the device hasn't finished with them.
  //
  // In certain cases, it can be known that a ShapedBuffer does not have any
  // conflicting accesses on the device and thus is eligible to be accessed at
  // any time from the host.
  //
  // This function returns true if device_buffer can be accessed immediately
  // without waiting for the Stream's previously enqueued items. This only
  // returns true if all subbuffers in device_buffer can be accessed
  // immediately.
  virtual bool CanShapedBufferBeAccessedNow(
      se::StreamExecutor* executor, const ShapedBuffer& device_buffer) const {
    return false;
  }

  // Equivalent to CanShapedBufferBeAccessedNow but for a single device buffer.
  virtual bool CanBufferBeAccessedNow(
      se::StreamExecutor* executor,
      const se::DeviceMemoryBase& device_buffer) const {
    return false;
  }

  /////
  // The TransferManager class also serves as a point to register objects for
  // the various platforms.

  // Registers the TransferManager singleton for the platform kind. This is
  // assumed to be a singleton, so no ownership is transferred.
  //
  // Precondition: a platform kind must not be registered more than once.
  typedef std::unique_ptr<TransferManager> (*TransferManagerCreationFunction)();
  static void RegisterTransferManager(
      se::Platform::Id platform_id,
      TransferManagerCreationFunction transfer_manager);
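
  // A hypothetical registration sketch following the usual module-initializer
  // pattern (`MyTransferManager` and `kMyPlatformId` are placeholders):
  //
  //   static std::unique_ptr<TransferManager> CreateMyTransferManager() {
  //     return std::make_unique<MyTransferManager>();
  //   }
  //   static bool module_initialized = [] {
  //     TransferManager::RegisterTransferManager(kMyPlatformId,
  //                                              &CreateMyTransferManager);
  //     return true;
  //   }();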

  // Returns the transfer manager singleton pointer if it is available for the
  // given platform, or an error status if it is not.
  static StatusOr<TransferManager*> GetForPlatform(
      const se::Platform* platform);
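
  // A hypothetical lookup sketch, assuming `platform` was obtained from
  // se::MultiPlatformManager:
  //
  //   TF_ASSIGN_OR_RETURN(TransferManager* transfer_manager,
  //                       TransferManager::GetForPlatform(platform));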

  // Writes the given device-memory pointers in 'elements' to the given region
  // to construct a tuple index table in the platform-specific tuple
  // representation.
  virtual Status WriteSingleTupleIndexTable(
      se::Stream* stream, absl::Span<const se::DeviceMemoryBase> elements,
      const Shape& shape, se::DeviceMemoryBase* region) = 0;

 protected:
  // Transfers a memory block of the given size from the device source into
  // the 'destination' buffer.
  //
  // size is the number of bytes to transfer to destination.
  virtual Status TransferBufferFromDevice(se::Stream* stream,
                                          const se::DeviceMemoryBase& source,
                                          int64_t size, void* destination);

  // Transfers a memory block of the given size from the 'source' buffer to
  // the given destination on the device.
  //
  // size is the number of bytes to transfer from source.
  virtual Status TransferBufferToDevice(se::Stream* stream, int64_t size,
                                        const void* source,
                                        se::DeviceMemoryBase* destination);

 private:
  // The mutex that guards the platform-to-transfer manager map.
  static absl::Mutex platform_transfer_manager_mutex_;

  // State kept for each kind of TransferManager.  Registration functions
  // set up creation_function, and then we use that to lazily create
  // "manager" the first time GetForPlatform is invoked for a particular id.
  struct State {
    std::unique_ptr<TransferManager> manager;
    TransferManagerCreationFunction creation_function = nullptr;
  };

  // Map from platform kind to transfer manager singleton.
  static absl::flat_hash_map<se::Platform::Id, State>*
  GetPlatformTransferManagers();
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_