/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.

#include "tensorflow/core/kernels/padding_fifo_queue.h"

#include <deque>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

PaddingFIFOQueue::PaddingFIFOQueue(
    int capacity, const DataTypeVector& component_dtypes,
    const std::vector<PartialTensorShape>& component_shapes, const string& name)
    : FIFOQueue(capacity, component_dtypes,
                ConvertShapesPartialDimensionsToZero(component_shapes), name),
      partial_shapes_(component_shapes) {}

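// Runs the base FIFOQueue initialization, then verifies that a (possibly
// partial) shape was supplied for every component dtype.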
Status PaddingFIFOQueue::Initialize() {
  Status s = FIFOQueue::Initialize();
  if (!s.ok()) return s;

  if (component_dtypes_.size() != partial_shapes_.size()) {
    return errors::InvalidArgument(
        "Shapes must be provided for all components, but received ",
        component_dtypes_.size(), " dtypes and ", partial_shapes_.size(),
        " shapes.");
  }

  return OkStatus();
}

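// Copies component `component` of `tuple` into a freshly allocated temporary
// tensor.  Used below to restore partially dequeued elements to the queue.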
/* static */
Status PaddingFIFOQueue::GetElementComponent(
    const PaddingFIFOQueue::Tuple& tuple, int component, OpKernelContext* ctx,
    Tensor* out_tensor) {
  TensorShape element_shape(tuple[component].shape());
  TF_RETURN_IF_ERROR(
      ctx->allocate_temp(tuple[component].dtype(), element_shape, out_tensor));
  *out_tensor = tuple[component];
  return OkStatus();
}

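// Dequeues a batch of `num_elements` elements.  Each component of the result
// is padded along its dynamic dimensions so that every element in the batch
// fits the largest shape observed for that component.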
void PaddingFIFOQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                      bool allow_small_batch,
                                      CallbackWithTuple callback) {
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output().
      // See similar comment in fifo_queue.cc
      Tensor element;
      // Here, ManyOutShape returns zeros for undetermined shapes,
      // which is exactly what we want to use.
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(component_dtypes_[i],
                                             ManyOutShape(i, 0), &element));
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if
      // possible.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, allow_small_batch,
           this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32_t queue_size = queues_[0].size();
            if (closed_ && queue_size < attempt->elements_requested) {
              // If we don't have enough for a full dequeue, we have
              // to reset the attempt tuple.
              if (!attempt->tuples.empty()) {
                // Restore already-dequeued elements to the front of the queue.
                for (int64_t i = attempt->tuples.size() - 1; i >= 0; --i) {
                  for (int j = 0; j < num_components(); ++j) {
                    Tensor element;
                    Status s = GetElementComponent(attempt->tuples[i], j,
                                                   attempt->context, &element);
                    if (!s.ok()) {
                      attempt->context->SetStatus(
                          errors::DataLoss("Failed to restore element from "
                                           "partially-dequeued batch "
                                           "to PaddingFIFOQueue: ",
                                           s.error_message()));
                    }
                    queues_[j].push_front(element);
                  }
                }
              }
              if (allow_small_batch && !queues_[0].empty()) {
                // Request all remaining elements in the queue.
                queue_size = queues_[0].size();
                attempt->tuples.clear();
                attempt->elements_requested = queue_size;
              } else {
                if (allow_small_batch) {
                  // There may be some enqueue attempts containing
                  // values.  If so, we'll yield and wait for them
                  // to add elements to the queue.
                  if (!enqueue_attempts_.empty()) return kProgress;
                }
                if (attempt->context->status().ok()) {
                  attempt->context->SetStatus(errors::OutOfRange(
                      "PaddingFIFOQueue '", name_, "' is closed and has ",
                      "insufficient elements (requested ",
                      attempt->elements_requested, ", current size ",
                      queue_size, ")"));
                }
                return kComplete;
              }
            }

            RunResult result = kNoProgress;
            for (; queue_size > 0; --queue_size) {
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->tuples.push_back(tuple);
              tuple.clear();
              --attempt->elements_requested;

              if (attempt->elements_requested == 0) {
                // Finished.  Allocate attempt->tuple and
                // copy from attempt->tuples to attempt->tuple.
                attempt->tuple.reserve(num_components());
                std::vector<Tuple>& tuples = attempt->tuples;

                std::vector<bool> dynamic_shape;
                const int64_t batch_size = tuples.size();

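                // Compute the padded output shape for each component: the
                // batch dimension comes first, statically known dimensions
                // come from partial_shapes_, and each unknown dimension is
                // set to the maximum size observed across the batch.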
                for (int i = 0; i < num_components(); ++i) {
                  const PartialTensorShape partial_shape =
                      PartialTensorShape({batch_size})
                          .Concatenate(partial_shapes_[i]);
                  TensorShape shape({batch_size});

                  for (int j = 0; j < partial_shape.dims() - 1; ++j) {
                    if (partial_shape.dim_size(j + 1) > -1) {
                      shape.AddDim(partial_shape.dim_size(j + 1));
                    } else {
                      // Expand sizes to match.
                      int64_t max_val = 0;
                      for (const Tuple& t : tuples) {
                        max_val = std::max(max_val, t[i].shape().dim_size(j));
                      }
                      shape.AddDim(max_val);
                    }
                  }

                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;

                  bool has_dynamic_shape = !partial_shape.IsFullyDefined();
                  if (has_dynamic_shape) {
                    // Set all values to zero because not all values
                    // will get written over.
                    attempt->context->SetStatus(SetElementZero(&element));
                    if (!attempt->context->status().ok()) return kComplete;
                  }

                  dynamic_shape.push_back(has_dynamic_shape);
                  attempt->tuple.emplace_back(element);
                }

                for (size_t index = 0; index < tuples.size(); ++index) {
                  for (int i = 0; i < num_components(); ++i) {
                    if (dynamic_shape[i]) {
                      // Slightly slower copy operation
                      attempt->context->SetStatus(CopyElementToLargerSlice(
                          tuples[index][i], &attempt->tuple[i], index));
                    } else {
                      attempt->context->SetStatus(
                          batch_util::CopyElementToSlice(
                              std::move(tuples[index][i]), &attempt->tuple[i],
                              index));
                    }
                    if (!attempt->context->status().ok()) return kComplete;
                  }
                }
                tuple = attempt->tuple;
                attempt->tuples.clear();
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}

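// Checks that every component of `tuple` is compatible with the
// corresponding partial shape this queue was constructed with.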
Status PaddingFIFOQueue::ValidateTuple(const Tuple& tuple) {
  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
  for (size_t i = 0; i < tuple.size(); ++i) {
    if (!partial_shapes_[i].IsCompatibleWith(tuple[i].shape())) {
      return errors::InvalidArgument("Shape mismatch in tuple component ", i,
                                     ". Expected ",
                                     partial_shapes_[i].DebugString(), ", got ",
                                     tuple[i].shape().DebugString());
    }
  }
  return OkStatus();
}

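// Checks that every component of a batched tuple is compatible with
// [batch_size] concatenated with the corresponding partial shape.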
Status PaddingFIFOQueue::ValidateManyTuple(const Tuple& tuple) {
  TF_RETURN_IF_ERROR(ValidateTupleCommon(tuple));
  const int64_t batch_size = tuple[0].dim_size(0);
  for (size_t i = 0; i < tuple.size(); ++i) {
    // Expected shape is [batch_size] + partial_shapes_[i]
    const PartialTensorShape expected_shape =
        PartialTensorShape({batch_size}).Concatenate(partial_shapes_[i]);
    if (!expected_shape.IsCompatibleWith(tuple[i].shape())) {
      return errors::InvalidArgument("Shape mismatch in tuple component ", i,
                                     ". Expected ",
                                     expected_shape.DebugString(), ", got ",
                                     tuple[i].shape().DebugString());
    }
  }
  return OkStatus();
}

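// Checks that the "shapes" attr of `node_def` is compatible with the
// component shapes of this (possibly shared) queue.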
Status PaddingFIFOQueue::CompatibleNodeDefShapes(
    const NodeDef& node_def) const {
  std::vector<PartialTensorShape> requested_shapes;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
  if (!PartialTensorShapeUtils::AreCompatible(requested_shapes,
                                              partial_shapes_)) {
    return errors::InvalidArgument(
        "Shared queue '", name_, "' has component shapes ",
        PartialTensorShapeUtils::PartialShapeListString(partial_shapes_),
        " but requested component shapes were ",
        PartialTensorShapeUtils::PartialShapeListString(requested_shapes));
  } else {
    return OkStatus();
  }
}

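// Verifies that `node_def` describes a PaddingFIFOQueue(V2) op whose
// capacity, dtypes, and shapes match this queue.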
Status PaddingFIFOQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "PaddingFIFOQueue").ok() &&
      !MatchesNodeDefOp(node_def, "PaddingFIFOQueueV2").ok()) {
    return errors::InvalidArgument("Expected PaddingFIFOQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
  TF_RETURN_IF_ERROR(MatchesNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(CompatibleNodeDefShapes(node_def));
  return OkStatus();
}

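// Checks that `element` has no more entries than one slice of `parent`
// along its first dimension, i.e. that it fits in the padded slot.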
static Status ValidateElementToLargerSlice(const Tensor& element,
                                           Tensor* parent) {
  DCHECK_NE(parent->dim_size(0), 0);
  if (element.NumElements() > (parent->NumElements() / parent->dim_size(0))) {
    TensorShape chip_shape = parent->shape();
    chip_shape.RemoveDim(0);
    return errors::Internal(
        "HandleElementToLargerSlice Cannot copy slice: number of entries in "
        "element is greater than number of elements in parent slice.  ",
        "Shapes are: [element]: ", element.shape().DebugString(),
        ", [parent slice]: ", chip_shape.DebugString());
  }
  return OkStatus();
}

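// Copies `element` (rank NDIMS) into slice `index` of `parent`
// (rank NDIMS + 1).  The element may occupy only part of the slice; the
// copied region is limited to the element's own dimensions.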
template <typename T, int NDIMS>
Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
                                  int index) {
  Status s = ValidateElementToLargerSlice(element, parent);
  if (!s.ok()) {
    return s;
  }
  if (element.NumElements() == 0) {
    return OkStatus();
  }
  auto element_t = element.tensor<T, NDIMS>();
  auto parent_t = parent->tensor<T, NDIMS + 1>();
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
  slice_indices[0] = index;
  Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_size;
  slice_size[0] = 1;
  for (size_t i = 1; i < slice_size.size(); ++i) {
    slice_size[i] = element_t.dimension(i - 1);
  }
  parent_t.slice(slice_indices, slice_size) = element_t.reshape(slice_size);
  return OkStatus();
}

namespace {

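// Dispatches HandleElementToLargerSlice on the element's dtype for a fixed
// rank NDIMS.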
template <int NDIMS>
Status HandleElementToLargerSliceWithRank(const Tensor& element, Tensor* parent,
                                          int index) {
#define HANDLE_TYPE(T)                                                   \
  case DataTypeToEnum<T>::value: {                                       \
    return HandleElementToLargerSlice<T, NDIMS>(element, parent, index); \
  }

  switch (element.dtype()) {
    TF_CALL_ALL_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
    default:
      return errors::Unimplemented(
          "HandleElementToLargerSliceWithRank Unhandled data type: ",
          DataTypeString(element.dtype()));
  }
}

}  // namespace

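// Copies `element` into slice `index` of the padded output tensor `parent`,
// dispatching on the element's rank (ranks 0 through 4 are supported).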
Status PaddingFIFOQueue::CopyElementToLargerSlice(const Tensor& element,
                                                  Tensor* parent, int index) {
  if (parent->dims() != element.dims() + 1) {
    return errors::Internal(
        "Mismatched ranks.  Element's rank is: ", element.dims(),
        " but element is meant to be a slice in output Tensor having rank: ",
        parent->dims(), " (should be: ", element.dims() + 1, ")");
  }

#define HANDLE_DIMS(NDIMS)                                                  \
  case NDIMS: {                                                             \
    TF_RETURN_IF_ERROR(                                                     \
        HandleElementToLargerSliceWithRank<NDIMS>(element, parent, index)); \
    return OkStatus();                                                      \
  }

  switch (element.dims()) {
    HANDLE_DIMS(0);
    HANDLE_DIMS(1);
    HANDLE_DIMS(2);
    HANDLE_DIMS(3);
    HANDLE_DIMS(4);
#undef HANDLE_DIMS
    default:
      return errors::Unimplemented("CopyElementToLargerSlice Unhandled rank: ",
                                   element.dims());
  }
}

// Static method
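// Fills `element` with the default value T() of its dtype so that regions
// not overwritten by a smaller element remain zero-padded.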
Status PaddingFIFOQueue::SetElementZero(Tensor* element) {
#define HANDLE_TYPE(T)                                \
  if (element->dtype() == DataTypeToEnum<T>::value) { \
    element->flat<T>().setConstant(T());              \
    return OkStatus();                                \
  }
  TF_CALL_ALL_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
  return errors::Unimplemented("SetElementZero Unhandled data type: ",
                               DataTypeString(element->dtype()));
}

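// Maps each partial component shape to a concrete TensorShape by replacing
// every unknown (-1) dimension with 0.  These are the per-element shapes
// passed to the base FIFOQueue constructor.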
std::vector<TensorShape> PaddingFIFOQueue::ConvertShapesPartialDimensionsToZero(
    const gtl::ArraySlice<PartialTensorShape>& partial_shapes) {
  std::vector<TensorShape> shapes(partial_shapes.size());
  for (size_t i = 0; i < shapes.size(); ++i) {
    const PartialTensorShape& partial = partial_shapes[i];
    TensorShape& shape = shapes[i];
    for (int64_t s : partial.dim_sizes()) shape.AddDim(s < 0 ? 0 : s);
  }
  return shapes;
}

}  // namespace tensorflow