/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/cpu/cpu_xfeed.h"

#include <cstring>
#include <limits>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/cleanup/cleanup.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/xfeed_manager.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/notification.h"

namespace xla {
namespace {

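// Buffer backing a single infeed transfer. It owns its storage and deletes
// itself when the runtime signals completion via Done().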
class CpuInfeedBuffer : public cpu::runtime::XfeedBuffer {
 public:
  explicit CpuInfeedBuffer(int32_t length)
      : length_(length), buffer_(new char[length]) {}
  ~CpuInfeedBuffer() override { delete[] buffer_; }

  int32_t length() override { return length_; }
  void* data() override { return buffer_; }
  void Done(StatusOr<Shape> /*shape*/) override { delete this; }

 private:
  int32_t length_;
  char* buffer_;
};

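// Buffer backing a single outfeed transfer. It wraps caller-owned destination
// memory; WaitForNotification() blocks until the device has populated the
// buffer and Done() has been called with the resulting shape.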
class CpuOutfeedBuffer : public cpu::runtime::XfeedBuffer {
 public:
  CpuOutfeedBuffer(void* destination, int32_t length)
      : destination_(destination), length_(length) {}

  StatusOr<Shape> WaitForNotification() {
    done_.WaitForNotification();
    return status_;
  }

  int32_t length() override { return length_; }
  void* data() override { return destination_; }
  void Done(StatusOr<Shape> shape) override {
    status_ = std::move(shape);
    done_.Notify();
  }

 private:
  void* destination_;
  int32_t length_;
  StatusOr<Shape> status_;
  tensorflow::Notification done_;
};

// Transfers infeed data to the device. The caller must eventually invoke
// Done() on the returned buffer to release the memory allocated for it.
StatusOr<cpu::runtime::XfeedBuffer*> TransferBufferToInfeedInternal(
    int64_t size, const void* source) {
  if (size > std::numeric_limits<int32_t>::max()) {
    return InvalidArgument("CPU infeed of %d bytes exceeds maximum of %d bytes",
                           size, std::numeric_limits<int32_t>::max());
  }

  if (size <= 0) {
    return InvalidArgument("Infeed shape must have positive size; got %d",
                           size);
  }

  auto size_32 = static_cast<int32_t>(size);
  auto queued_buffer = new CpuInfeedBuffer(size_32);
  std::memcpy(queued_buffer->data(), source, size);

  return queued_buffer;
}

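// Copies `size` bytes from `source` into a freshly allocated infeed buffer
// and enqueues it on the infeed queue of device `device_ordinal`.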
Status TransferBufferToInfeed(int device_ordinal, int64_t size,
                              const void* source) {
  TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer,
                      TransferBufferToInfeedInternal(size, source));

  cpu::runtime::XfeedManager* xfeed_manager =
      cpu::runtime::GetXfeedManager(device_ordinal);
  xfeed_manager->infeed()->EnqueueBuffersAtomically({buffer});

  return OkStatus();
}

StatusOr<Shape> TransferBuffersFromOutfeedInternal(
    int device_ordinal, absl::Span<const std::pair<void*, int64_t>> buffer_data,
    bool is_tuple) {
  std::vector<std::unique_ptr<CpuOutfeedBuffer>> buffers;
  for (const auto& b : buffer_data) {
    int64_t size = b.second;
    if (size > std::numeric_limits<int32_t>::max()) {
      return InvalidArgument("Outfeed shape is too large: needs %d bytes",
                             size);
    }

    if (size < 0) {
      return InvalidArgument(
          "Outfeed shape must have non-negative size; got %d", size);
    }

    auto size_32 = static_cast<int32_t>(size);
    VLOG(2)
        << "Enqueueing outfeed buffer (for the device to populate) of length "
        << size_32 << "B";
    buffers.push_back(std::make_unique<CpuOutfeedBuffer>(b.first, size_32));
  }

  std::vector<cpu::runtime::XfeedBuffer*> buffer_pointers;
  buffer_pointers.reserve(buffers.size());
  for (auto& b : buffers) {
    buffer_pointers.push_back(b.get());
  }

  cpu::runtime::XfeedManager* xfeed_manager =
      cpu::runtime::GetXfeedManager(device_ordinal);
  xfeed_manager->outfeed()->EnqueueBuffersAtomically(buffer_pointers);
  VLOG(2) << "Waiting for buffer to be notified as populated.";
  std::vector<Shape> outfed_shapes;
  outfed_shapes.reserve(buffers.size());
  for (auto& buffer : buffers) {
    TF_ASSIGN_OR_RETURN(Shape outfed_shape, buffer->WaitForNotification());
    outfed_shapes.push_back(std::move(outfed_shape));
  }
  if (is_tuple) {
    return ShapeUtil::MakeTupleShape(outfed_shapes);
  }
  TF_RET_CHECK(outfed_shapes.size() == 1);
  return std::move(outfed_shapes[0]);
}

StatusOr<Shape> TransferArrayBufferFromOutfeed(int device_ordinal,
                                               void* destination,
                                               int64_t size_bytes) {
  return TransferBuffersFromOutfeedInternal(
      device_ordinal, {{destination, size_bytes}}, /*is_tuple=*/false);
}

StatusOr<Shape> TransferTupleBuffersFromOutfeed(
    int device_ordinal,
    absl::Span<const std::pair<void*, int64_t>> buffer_data) {
  return TransferBuffersFromOutfeedInternal(device_ordinal, buffer_data,
                                            /*is_tuple=*/true);
}
}  // namespace

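// Illustrative usage (a sketch, not part of this file): enqueue a rank-1 f32
// literal on device 0. LiteralSlice converts implicitly from Literal.
//
//   Literal data = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f});
//   TF_RETURN_IF_ERROR(
//       TransferLiteralToInfeedOnCpu(/*device_ordinal=*/0, data));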
Status TransferLiteralToInfeedOnCpu(int device_ordinal,
                                    const LiteralSlice& literal) {
  const Shape& shape = literal.shape();
  VLOG(2) << "Transferring literal to infeed with shape: "
          << ShapeUtil::HumanString(shape);

  if (!shape.IsTuple()) {
    int64_t size = cpu::runtime::GetByteSizeRequirement(shape, sizeof(void*));
    return TransferBufferToInfeed(device_ordinal, size, literal.untyped_data());
  }

  if (ShapeUtil::IsNestedTuple(shape)) {
    return Unimplemented(
        "Infeed with a nested tuple shape is not supported: %s",
        ShapeUtil::HumanString(literal.shape()));
  }

  // For a tuple, we transfer each of its elements to the device and
  // enqueue the resulting destination device addresses with the
  // infeed manager.
  std::vector<cpu::runtime::XfeedBuffer*> buffers;
  buffers.reserve(ShapeUtil::TupleElementCount(shape));
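  // If transferring any tuple element fails below, release the buffers
  // allocated so far; the cleanup is cancelled once every buffer has been
  // handed off to the xfeed manager.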
  absl::Cleanup cleanup = [&buffers]() {
    for (cpu::runtime::XfeedBuffer* b : buffers) {
      b->Done(Cancelled("Failed to infeed buffer to device."));
    }
  };

  for (int64_t i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
    const Shape& tuple_element_shape = ShapeUtil::GetSubshape(shape, {i});
    int64_t tuple_element_size = cpu::runtime::GetByteSizeRequirement(
        tuple_element_shape, sizeof(void*));
    TF_ASSIGN_OR_RETURN(cpu::runtime::XfeedBuffer * buffer,
                        TransferBufferToInfeedInternal(
                            tuple_element_size, literal.untyped_data({i})));
    buffers.push_back(buffer);
  }

  cpu::runtime::XfeedManager* xfeed_manager =
      cpu::runtime::GetXfeedManager(device_ordinal);
  xfeed_manager->infeed()->EnqueueBuffersAtomically(buffers);

  std::move(cleanup).Cancel();
  return OkStatus();
}

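// Illustrative usage (a sketch, not part of this file): receive a rank-1 f32
// outfeed of three elements from device 0 into a preallocated literal.
//
//   Literal result(ShapeUtil::MakeShape(F32, {3}));
//   TF_RETURN_IF_ERROR(TransferLiteralFromOutfeedOnCpu(
//       /*device_ordinal=*/0, MutableBorrowingLiteral(&result)));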
Status TransferLiteralFromOutfeedOnCpu(int device_ordinal,
                                       MutableBorrowingLiteral literal) {
  if (!literal.shape().IsTuple()) {
    int64_t size =
        cpu::runtime::GetByteSizeRequirement(literal.shape(), sizeof(void*));
    TF_ASSIGN_OR_RETURN(Shape received_shape,
                        TransferArrayBufferFromOutfeed(
                            device_ordinal, literal.untyped_data(), size));
    TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape()))
        << "Shape received from outfeed "
        << ShapeUtil::HumanString(received_shape)
        << " did not match the shape that was requested for outfeed: "
        << ShapeUtil::HumanString(literal.shape());
    TF_RET_CHECK(size == cpu::runtime::GetByteSizeRequirement(received_shape,
                                                              sizeof(void*)));
    *literal.mutable_shape_do_not_use() = received_shape;
    return OkStatus();
  }

  if (ShapeUtil::IsNestedTuple(literal.shape())) {
    return Unimplemented(
        "Nested tuple outfeeds are not yet implemented on CPU.");
  }

  std::vector<std::pair<void*, int64_t>> buffer_data;
  for (int i = 0; i < literal.shape().tuple_shapes_size(); ++i) {
    const Shape& tuple_element_shape =
        ShapeUtil::GetTupleElementShape(literal.shape(), i);
    int64_t size = cpu::runtime::GetByteSizeRequirement(tuple_element_shape,
                                                        sizeof(void*));
    buffer_data.push_back({literal.untyped_data({i}), size});
  }

  TF_ASSIGN_OR_RETURN(Shape received_shape, TransferTupleBuffersFromOutfeed(
                                                device_ordinal, buffer_data));

  TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal.shape()))
      << "Shape received from outfeed "
      << ShapeUtil::HumanString(received_shape)
      << " did not match the shape that was requested for outfeed: "
      << ShapeUtil::HumanString(literal.shape());
  TF_RET_CHECK(
      cpu::runtime::GetByteSizeRequirement(literal.shape(), sizeof(void*)) ==
      cpu::runtime::GetByteSizeRequirement(received_shape, sizeof(void*)));

  return OkStatus();
}

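// Resolves the dynamic dimensions of `device_shape` in place. On CPU, each
// dynamically shaped buffer stores its full static payload followed by one
// int32 of metadata per dimension holding that dimension's true size; this
// reads the metadata and rewrites the shape to the static sizes it encodes.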
Status ReadDynamicShapesOnCpu(
    ShapedBuffer* device_buffer, Shape* device_shape,
    HloCostAnalysis::ShapeSizeFunction shape_size_fn) {
  TF_RET_CHECK(device_shape->is_dynamic());
  Shape original_device_shape = *device_shape;
  TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus(
      [&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) {
        const Shape& buffer_shape =
            ShapeUtil::GetSubshape(*device_shape, index);
        if (buffer_shape.IsTuple()) {
          return OkStatus();
        }
        Shape& device_sub_shape =
            *ShapeUtil::GetMutableSubshape(device_shape, index);
        if (device_sub_shape.is_static()) {
          return OkStatus();
        }
        void* memory = buffer->opaque();

        // Read the dynamic shape metadata that trails the buffer's static
        // payload.
        Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape);
        const int64_t offset = shape_size_fn(buffer_shape_static);
        int64_t metadata_size = shape_size_fn(buffer_shape) - offset;
        if (metadata_size == 0) {
          return InvalidArgument("Dynamic shape metadata size should not be 0");
        }
        auto buffer_8 = static_cast<int8_t*>(memory);
        auto metadata_buffer = reinterpret_cast<int32_t*>(buffer_8 + offset);

        // Update each dynamic dimension to the true size recorded in the
        // metadata.
        for (int64_t i = 0; i < device_sub_shape.rank(); ++i) {
          device_sub_shape.mutable_dimensions()[i] = metadata_buffer[i];
        }
        return OkStatus();
      }));
  device_shape->clear_dynamic_dimensions();

  TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
                                                   original_device_shape));
  return OkStatus();
}
}  // namespace xla