/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/compiler/xla/service/cpu/cpu_xfeed.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/base/casts.h"
24 #include "absl/cleanup/cleanup.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/literal_util.h"
27 #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
28 #include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
29 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
30 #include "tensorflow/compiler/xla/service/shaped_buffer.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/status_macros.h"
33 #include "tensorflow/compiler/xla/statusor.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/notification.h"
39 
40 namespace xla {
41 namespace {
42 
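// Host-side buffer backing a single infeed transfer. The buffer owns its
// storage; once enqueued, the xfeed runtime takes ownership of the object
// and signals completion by calling Done(), at which point the buffer
// deletes itself.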
class CpuInfeedBuffer : public cpu::runtime::XfeedBuffer {
 public:
  explicit CpuInfeedBuffer(int32_t length)
      : length_(length), buffer_(new char[length]) {}
  ~CpuInfeedBuffer() override { delete[] buffer_; }

  int32_t length() override { return length_; }
  void* data() override { return buffer_; }
  void Done(StatusOr<Shape> /*shape*/) override { delete this; }

 private:
  int32_t length_;
  char* buffer_;
};

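// Host-side buffer for a single outfeed transfer. The device populates
// `destination_`, then reports completion (and the shape it actually
// produced) through Done(); WaitForNotification() blocks until that happens.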
class CpuOutfeedBuffer : public cpu::runtime::XfeedBuffer {
 public:
  CpuOutfeedBuffer(void* destination, int32_t length)
      : destination_(destination), length_(length) {}

  StatusOr<Shape> WaitForNotification() {
    done_.WaitForNotification();
    return status_;
  }

  int32_t length() override { return length_; }
  void* data() override { return destination_; }
  void Done(StatusOr<Shape> shape) override {
    status_ = std::move(shape);
    done_.Notify();
  }

 private:
  void* destination_;
  int32_t length_;
  StatusOr<Shape> status_;
  tensorflow::Notification done_;
};

// Copies infeed data into a freshly allocated host buffer. Done() must
// eventually be called on the returned buffer to release the memory it
// allocated; the xfeed runtime does this once the buffer is enqueued and
// consumed.
StatusOr<cpu::runtime::XfeedBuffer*> TransferBufferToInfeedInternal(
    int64_t size, const void* source) {
  if (size > std::numeric_limits<int32_t>::max()) {
    return InvalidArgument("CPU infeed of %d bytes exceeds maximum of %d bytes",
                           size, std::numeric_limits<int32_t>::max());
  }

  if (size <= 0) {
    return InvalidArgument("Infeed shape must have positive size; got %d",
                           size);
  }

  auto size_32 = static_cast<int32_t>(size);
  auto queued_buffer = new CpuInfeedBuffer(size_32);
  std::memcpy(queued_buffer->data(), source, size);

  return queued_buffer;
}

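// Copies `source` into a new infeed buffer and atomically enqueues it on the
// infeed queue of the xfeed manager for `device_ordinal`.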
Status TransferBufferToInfeed(int device_ordinal, int64_t size,
                              const void* source) {
  TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer,
                      TransferBufferToInfeedInternal(size, source));

  cpu::runtime::XfeedManager* xfeed_manager =
      cpu::runtime::GetXfeedManager(device_ordinal);
  xfeed_manager->infeed()->EnqueueBuffersAtomically({buffer});

  return OkStatus();
}

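// Enqueues one outfeed buffer per (destination, size) pair and blocks until
// the device has populated all of them. Returns the shape the device
// reported: the single element shape, or a tuple shape of all elements when
// `is_tuple` is set.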
StatusOr<Shape> TransferBuffersFromOutfeedInternal(
    int device_ordinal, absl::Span<const std::pair<void*, int64_t>> buffer_data,
    bool is_tuple) {
  std::vector<std::unique_ptr<CpuOutfeedBuffer>> buffers;
  for (const auto& b : buffer_data) {
    int64_t size = b.second;
    if (size > std::numeric_limits<int32_t>::max()) {
      return InvalidArgument("Outfeed shape is too large: needs %d bytes",
                             size);
    }

    if (size < 0) {
      return InvalidArgument(
          "Outfeed shape must have non-negative size; got %d", size);
    }

    auto size_32 = static_cast<int32_t>(size);
    VLOG(2)
        << "Enqueueing outfeed buffer (for the device to populate) of length "
        << size_32 << "B";
    buffers.push_back(std::make_unique<CpuOutfeedBuffer>(b.first, size_32));
  }

  std::vector<cpu::runtime::XfeedBuffer*> buffer_pointers;
  buffer_pointers.reserve(buffers.size());
  for (auto& b : buffers) {
    buffer_pointers.push_back(b.get());
  }

  cpu::runtime::XfeedManager* xfeed_manager =
      cpu::runtime::GetXfeedManager(device_ordinal);
  xfeed_manager->outfeed()->EnqueueBuffersAtomically(buffer_pointers);
  VLOG(2) << "Waiting for buffer to be notified as populated.";
  std::vector<Shape> outfed_shapes;
  outfed_shapes.reserve(buffers.size());
  for (auto& buffer : buffers) {
    TF_ASSIGN_OR_RETURN(Shape outfed_shape, buffer->WaitForNotification());
    outfed_shapes.push_back(std::move(outfed_shape));
  }
  if (is_tuple) {
    return ShapeUtil::MakeTupleShape(outfed_shapes);
  }
  TF_RET_CHECK(outfed_shapes.size() == 1);
  return std::move(outfed_shapes[0]);
}

StatusOr<Shape> TransferArrayBufferFromOutfeed(int device_ordinal,
                                               void* destination,
                                               int64_t size_bytes) {
  return TransferBuffersFromOutfeedInternal(
      device_ordinal, {{destination, size_bytes}}, /*is_tuple=*/false);
}

StatusOr<Shape> TransferTupleBuffersFromOutfeed(
    int device_ordinal,
    absl::Span<const std::pair<void*, int64_t>> buffer_data) {
  return TransferBuffersFromOutfeedInternal(device_ordinal, buffer_data,
                                            /*is_tuple=*/true);
}
}  // namespace

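// Feeds `literal` to the infeed queue of the given CPU device. Array
// literals are transferred as a single buffer; shallow tuples are
// transferred element by element. Nested tuples are not supported.
//
// A minimal usage sketch (the surrounding setup is hypothetical):
//
//   Literal literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
//   TF_RETURN_IF_ERROR(
//       TransferLiteralToInfeedOnCpu(/*device_ordinal=*/0, literal));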
Status TransferLiteralToInfeedOnCpu(int device_ordinal,
                                    const LiteralSlice& literal) {
  const Shape& shape = literal.shape();
  VLOG(2) << "Transferring literal to infeed with shape: "
          << ShapeUtil::HumanString(shape);

  if (!shape.IsTuple()) {
    int64_t size = cpu::runtime::GetByteSizeRequirement(shape, sizeof(void*));
    return TransferBufferToInfeed(device_ordinal, size, literal.untyped_data());
  }

  if (ShapeUtil::IsNestedTuple(shape)) {
    return Unimplemented(
        "Infeed with a nested tuple shape is not supported: %s",
        ShapeUtil::HumanString(literal.shape()));
  }

  // For a tuple, we transfer each of its elements to the device and
  // enqueue the resulting destination device addresses with the
  // infeed manager.
  std::vector<cpu::runtime::XfeedBuffer*> buffers;
  buffers.reserve(ShapeUtil::TupleElementCount(shape));
  absl::Cleanup cleanup = [&buffers]() {
    for (cpu::runtime::XfeedBuffer* b : buffers) {
      b->Done(Cancelled("Failed to infeed buffer to device."));
    }
  };

  for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
    const Shape& tuple_element_shape = ShapeUtil::GetSubshape(shape, {i});
    int64_t tuple_element_size = cpu::runtime::GetByteSizeRequirement(
        tuple_element_shape, sizeof(void*));
    TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer,
                        TransferBufferToInfeedInternal(
                            tuple_element_size, literal.untyped_data({i})));
    buffers.push_back(buffer);
  }

  cpu::runtime::XfeedManager* xfeed_manager =
      cpu::runtime::GetXfeedManager(device_ordinal);
  xfeed_manager->infeed()->EnqueueBuffersAtomically(buffers);

  // The infeed manager now owns the buffers, so the failure cleanup above
  // must not run.
  std::move(cleanup).Cancel();
  return OkStatus();
}

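// Receives one literal from the outfeed queue of the given CPU device into
// `literal`, which must already carry the expected (non-nested) shape.
//
// A minimal usage sketch (hypothetical setup; the destination literal is
// pre-shaped to match what the computation outfeeds):
//
//   Literal result(ShapeUtil::MakeShape(F32, {3}));
//   TF_RETURN_IF_ERROR(TransferLiteralFromOutfeedOnCpu(
//       /*device_ordinal=*/0, MutableBorrowingLiteral(&result)));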
Status TransferLiteralFromOutfeedOnCpu(int device_ordinal,
                                       MutableBorrowingLiteral literal) {
  if (!literal.shape().IsTuple()) {
    int64_t size =
        cpu::runtime::GetByteSizeRequirement(literal.shape(), sizeof(void*));
    TF_ASSIGN_OR_RETURN(Shape received_shape,
                        TransferArrayBufferFromOutfeed(
                            device_ordinal, literal.untyped_data(), size));
    TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape()))
        << "Shape received from outfeed "
        << ShapeUtil::HumanString(received_shape)
        << " did not match the shape that was requested for outfeed: "
        << ShapeUtil::HumanString(literal.shape());
    TF_RET_CHECK(size == cpu::runtime::GetByteSizeRequirement(received_shape,
                                                              sizeof(void*)));
    *literal.mutable_shape_do_not_use() = received_shape;
    return OkStatus();
  }

  if (ShapeUtil::IsNestedTuple(literal.shape())) {
    return Unimplemented(
        "Nested tuple outfeeds are not yet implemented on CPU.");
  }

  std::vector<std::pair<void*, int64_t>> buffer_data;
  for (int i = 0; i < literal.shape().tuple_shapes_size(); ++i) {
    const Shape& tuple_element_shape =
        ShapeUtil::GetTupleElementShape(literal.shape(), i);
    int64_t size = cpu::runtime::GetByteSizeRequirement(tuple_element_shape,
                                                        sizeof(void*));
    buffer_data.push_back({literal.untyped_data({i}), size});
  }

  TF_ASSIGN_OR_RETURN(Shape received_shape, TransferTupleBuffersFromOutfeed(
                                                device_ordinal, buffer_data));

  TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape()))
      << "Shape received from outfeed "
      << ShapeUtil::HumanString(received_shape)
      << " did not match the shape that was requested for outfeed: "
      << ShapeUtil::HumanString(literal.shape());
  TF_RET_CHECK(
      cpu::runtime::GetByteSizeRequirement(literal.shape(), sizeof(void*)) ==
      cpu::runtime::GetByteSizeRequirement(received_shape, sizeof(void*)));

  return OkStatus();
}

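// Reads the dynamic-dimension metadata that the CPU backend appends to each
// dynamically shaped buffer in `device_buffer`, and rewrites the dynamic
// dimensions of `device_shape` to the actual runtime sizes. The metadata is
// a trailing array of int32 values, one per dimension, stored after the
// statically sized payload.
//
// For example (sizes illustrative, assuming `shape_size_fn` counts raw
// bytes): an f32[<=4] buffer holding 3 valid elements would be 16 bytes of
// payload followed by one int32 metadata word with value 3.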
Status ReadDynamicShapesOnCpu(
    ShapedBuffer* device_buffer, Shape* device_shape,
    HloCostAnalysis::ShapeSizeFunction shape_size_fn) {
  TF_RET_CHECK(device_shape->is_dynamic());
  Shape original_device_shape = *device_shape;
  TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus(
      [&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) {
        const Shape& buffer_shape =
            ShapeUtil::GetSubshape(*device_shape, index);
        if (buffer_shape.IsTuple()) {
          return OkStatus();
        }
        Shape& device_sub_shape =
            *ShapeUtil::GetMutableSubshape(device_shape, index);
        if (device_sub_shape.is_static()) {
          return OkStatus();
        }
        void* memory = buffer->opaque();

        // Read the dynamic shape metadata from the device stream.
        Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape);
        const int64_t offset = shape_size_fn(buffer_shape_static);
        int64_t metadata_size = shape_size_fn(buffer_shape) - offset;
        if (metadata_size == 0) {
          return InvalidArgument("Dynamic shape metadata size should not be 0");
        }
        auto buffer_8 = static_cast<int8_t*>(memory);
        auto metadata_buffer = reinterpret_cast<int32_t*>(buffer_8 + offset);

        // Update shape size from metadata.
        for (int64_t i = 0; i < device_sub_shape.rank(); ++i) {
          device_sub_shape.mutable_dimensions()[i] = metadata_buffer[i];
        }
        return OkStatus();
      }));
  device_shape->clear_dynamic_dimensions();

  TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
                                                   original_device_shape));
  return OkStatus();
}
}  // namespace xla