/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/data_flow_ops.cc.

#include "tensorflow/core/kernels/priority_queue.h"

#include <deque>
#include <queue>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/queue_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/priority_queue_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/batch_util.h"

namespace tensorflow {

PriorityQueue::PriorityQueue(int32_t capacity,
                             const DataTypeVector& component_dtypes,
                             const std::vector<TensorShape>& component_shapes,
                             const string& name)
    : TypedQueue(capacity, component_dtypes, component_shapes, name) {}

Status PriorityQueue::Initialize() {
  Status s = TypedQueue::Initialize();
  if (!s.ok()) return s;

  mutex_lock lock(mu_);
  if (component_dtypes_[0] != DT_INT64) {
    return errors::InvalidArgument(
        "PriorityQueue priority index component must be type int64, but "
        "dtype is: ",
        DataTypeString(component_dtypes_[0]));
  }
  if (specified_shapes() && !TensorShapeUtils::IsScalar(component_shapes_[0])) {
    return errors::InvalidArgument(
        "PriorityQueue priority index component must be a scalar, but shape "
        "is: ",
        component_shapes_[0].DebugString());
  }
  return OkStatus();
}
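
// Illustrative sketch (not part of the original file): constructing and
// initializing a PriorityQueue. The component lists below are hypothetical;
// the requirements Initialize() actually enforces are that component 0 is
// DT_INT64 and, when shapes are specified, scalar.
//
//   DataTypeVector dtypes = {DT_INT64, DT_FLOAT};         // priority + payload
//   std::vector<TensorShape> shapes = {TensorShape({}),   // scalar priority
//                                      TensorShape({2})};
//   PriorityQueue* queue =
//       new PriorityQueue(/*capacity=*/16, dtypes, shapes, "example_queue");
//   Status s = queue->Initialize();
//   if (!s.ok()) queue->Unref();  // Queues are ref-counted resources.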

void PriorityQueue::DequeueLocked(OpKernelContext* ctx, Tuple* tuple) {
  DCHECK_GT(queues_[0].size(), 0);
  (*tuple).reserve(num_components());
  for (int i = 0; i < num_components(); ++i) {
    Tensor tensor = gtl::ConsumeTop(&queues_[i]).second;
    (*tuple).push_back(tensor);
  }
}

void PriorityQueue::TryEnqueue(const Tuple& tuple, OpKernelContext* ctx,
                               DoneCallback callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          1, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("PriorityQueue '", name_, "' is closed."));
              return kComplete;
            }
            if (queues_[0].size() < static_cast<size_t>(capacity_)) {
              if (!TensorShapeUtils::IsScalar(tuple[0].shape())) {
                attempt->context->SetStatus(errors::InvalidArgument(
                    "Expected the priority element to be a scalar, but "
                    "received shape: ",
                    tuple[0].shape().DebugString()));
                return kComplete;
              }
              const int64_t priority = tuple[0].scalar<int64_t>()();
              for (int i = 0; i < num_components(); ++i) {
                queues_[i].emplace(priority, tuple[i]);
              }
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}
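
// Illustrative sketch (not part of the original file): how an enqueue kernel
// might call TryEnqueue(). Component 0 carries the scalar int64 priority and
// the remaining components carry the payload; `queue`, `ctx`, and the tensor
// values are hypothetical.
//
//   PriorityQueue::Tuple tuple;
//   Tensor priority(DT_INT64, TensorShape({}));
//   priority.scalar<int64_t>()() = 5;
//   Tensor payload(DT_FLOAT, TensorShape({2}));
//   payload.vec<float>()(0) = 1.0f;
//   payload.vec<float>()(1) = 2.0f;
//   tuple.push_back(priority);
//   tuple.push_back(payload);
//   queue->TryEnqueue(tuple, ctx, [ctx]() {
//     // Runs once the element is enqueued, or after the attempt is
//     // cancelled or fails (check ctx->status()).
//   });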

/* static */
Status PriorityQueue::GetElementComponentFromBatch(
    const PriorityQueue::Tuple& tuple, int index, int component,
    OpKernelContext* ctx, Tensor* out_element) {
  TensorShape element_shape(tuple[component].shape());
  element_shape.RemoveDim(0);
  TF_RETURN_IF_ERROR(
      ctx->allocate_temp(tuple[component].dtype(), element_shape, out_element));
  TF_RETURN_IF_ERROR(
      batch_util::CopySliceToElement(tuple[component], out_element, index));
  return OkStatus();
}
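
// Worked example (illustrative, not part of the original file): if
// tuple[component] has shape [batch, 3, 4], the element shape after
// RemoveDim(0) is [3, 4]; *out_element is allocated with that shape and row
// `index` of the batch is copied into it via batch_util::CopySliceToElement.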

void PriorityQueue::TryEnqueueMany(const Tuple& tuple, OpKernelContext* ctx,
                                   DoneCallback callback) {
  const int64_t batch_size = tuple[0].dim_size(0);
  if (batch_size == 0) {
    callback();
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kEnqueue, cm, token); });
    if (!already_cancelled) {
      enqueue_attempts_.emplace_back(
          batch_size, callback, ctx, cm, token,
          [tuple, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            if (closed_) {
              attempt->context->SetStatus(
                  errors::Cancelled("PriorityQueue '", name_, "' is closed."));
              return kComplete;
            }
            RunResult result = kNoProgress;
            while (queues_[0].size() < static_cast<size_t>(capacity_)) {
              result = kProgress;
              const int index =
                  tuple[0].dim_size(0) - attempt->elements_requested;

              Tensor priority_element;
              attempt->context->SetStatus(GetElementComponentFromBatch(
                  tuple, index, 0, attempt->context, &priority_element));
              if (!attempt->context->status().ok()) return kComplete;
              if (!TensorShapeUtils::IsScalar(priority_element.shape())) {
                attempt->context->SetStatus(errors::InvalidArgument(
                    "Expected the priority element to be a scalar, but "
                    "received shape: ",
                    priority_element.shape().DebugString()));
                return kComplete;
              }
              const int64_t priority = priority_element.scalar<int64_t>()();
              for (int i = 0; i < num_components(); ++i) {
                Tensor element;
                attempt->context->SetStatus(GetElementComponentFromBatch(
                    tuple, index, i, attempt->context, &element));
                if (!attempt->context->status().ok()) return kComplete;
                queues_[i].emplace(priority, element);
              }
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Enqueue operation was cancelled"));
    callback();
  }
}
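
// Worked example (illustrative, not part of the original file): for a batch
// of 3, attempt->elements_requested starts at 3, so successive iterations
// enqueue batch index 0 (3 - 3), then 1 (3 - 2), then 2 (3 - 1), preserving
// batch order as elements_requested counts down to zero. If the queue reaches
// capacity before the batch is finished, the attempt returns kProgress (or
// kNoProgress if nothing was enqueued this pass) and is retried later.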

void PriorityQueue::TryDequeue(OpKernelContext* ctx,
                               CallbackWithTuple callback) {
  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          1, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this](Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            const int32_t s = queues_[0].size();
            if (closed_ && s == 0) {
              attempt->context->SetStatus(errors::OutOfRange(
                  "PriorityQueue '", name_, "' is closed and has ",
                  "insufficient elements (requested ", 1, ", current size ", s,
                  ")"));
              return kComplete;
            }
            if (s > 0) {
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              attempt->done_callback = [callback, tuple]() { callback(tuple); };
              return kComplete;
            } else {
              return kNoProgress;
            }
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}
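
// Illustrative sketch (not part of the original file): how a dequeue kernel
// might call TryDequeue(). The callback receives an empty Tuple on
// cancellation or error, so callers should check ctx->status() (or the tuple
// size) before using the components; `queue`, `ctx`, and `done` are
// hypothetical.
//
//   queue->TryDequeue(ctx, [ctx, done](const PriorityQueue::Tuple& tuple) {
//     if (!ctx->status().ok() || tuple.empty()) {
//       done();
//       return;
//     }
//     const int64_t priority = tuple[0].scalar<int64_t>()();
//     // tuple[1], tuple[2], ... hold the payload components.
//     done();
//   });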

void PriorityQueue::TryDequeueMany(int num_elements, OpKernelContext* ctx,
                                   bool allow_small_batch,
                                   CallbackWithTuple callback) {
  if (!specified_shapes()) {
    ctx->SetStatus(
        errors::InvalidArgument("PriorityQueue's DequeueMany requires the "
                                "components to have specified shapes."));
    callback(Tuple());
    return;
  }
  if (num_elements == 0) {
    Tuple tuple;
    tuple.reserve(num_components());
    for (int i = 0; i < num_components(); ++i) {
      // TODO(josh11b,misard): Switch to allocate_output().  Problem is
      // this breaks the abstraction boundary since we don't *really*
      // know if and how the Tensors in the tuple we pass to callback
      // correspond to the outputs of *ctx.  For example, the
      // ReaderRead Op uses TryDequeue() to get a filename out of a
      // queue that is used internally by the reader and is not
      // associated with any output of the ReaderRead.
      // mrry@ adds:
      // Maybe we need to pass a std::function<Tensor*(...)> (or
      // better signature) that calls the appropriate allocator
      // function in addition to ctx?  (Or support a shim Allocator
      // that has an internal OpKernelContext*, and dispatches to the
      // appropriate method?)
      // misard@ adds:
      // I don't see that a std::function would help. The problem is
      // that at this point (allocation time) the system doesn't know
      // what is going to happen to the element read out of the
      // queue. As long as we keep the generality that TensorFlow Ops
      // do their own dynamic allocation in arbitrary C++ code, we
      // need to preserve robustness to allocating output Tensors with
      // the 'wrong' attributes, and fixing up with a copy. The only
      // improvement I can see here in the future would be to support
      // an optimized case where the queue 'knows' what attributes to
      // use, and plumbs them through here.
      Tensor element;
      Status status = ctx->allocate_temp(component_dtypes_[i],
                                         ManyOutShape(i, 0), &element);
      if (!status.ok()) {
        ctx->SetStatus(status);
        callback(Tuple());
        return;
      }
      tuple.emplace_back(element);
    }
    callback(tuple);
    return;
  }

  CancellationManager* cm = ctx->cancellation_manager();
  CancellationToken token = cm->get_cancellation_token();
  bool already_cancelled;
  {
    mutex_lock l(mu_);
    already_cancelled = !cm->RegisterCallback(
        token, [this, cm, token]() { Cancel(kDequeue, cm, token); });
    if (!already_cancelled) {
      // TODO(josh11b): This makes two copies of callback, avoid this if possible.
      dequeue_attempts_.emplace_back(
          num_elements, [callback]() { callback(Tuple()); }, ctx, cm, token,
          [callback, this, allow_small_batch](
              Attempt* attempt) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
            int32_t s = queues_[0].size();
            // Return OutOfRange if closed and there are fewer elements
            // available than requested.  *Unless* allow_small_batch
            // is true, in which case we return as many elements as
            // possible.
            if (closed_) {
              if (s == 0 ||
                  (!allow_small_batch && s < attempt->elements_requested)) {
                attempt->context->SetStatus(errors::OutOfRange(
                    "PriorityQueue '", name_, "' is closed and has ",
                    "insufficient elements (requested ",
                    attempt->elements_requested, ", current size ", s, ")"));
                return kComplete;
              }
            }

            // The PriorityQueue is expected to always return a
            // sorted set of entries.  In order to do this, the underlying
            // queue must have at least this many entries already.
            // Doing the dynamic thing and pulling out a portion at a
            // time leads to unordered output in calls to DequeueMany.
            //
            // An alternative solution is to store the attempt tuple
            // entries in an identical priority_queue and push onto
            // this queue dynamically, then when it is full, do all
            // the Tensor concatenation at the very end.
            // TODO(ebrevdo): Change approach if this leads to locking issues.
            if (s < attempt->elements_requested) {
              // If we have no elements at all, then wait.
              // Otherwise proceed if closed and allow_small_batch is true.
              // Otherwise wait until we have more enqueued elements.
              if (s == 0 || !(closed_ && allow_small_batch)) {
                return kNoProgress;
              }
            }

            RunResult result = kNoProgress;
            for (; s > 0; --s) {
              if (attempt->tuple.empty()) {
                // Only allocate tuple when we have something to dequeue
                // so we don't use excessive memory when there are many
                // blocked dequeue attempts waiting.
                attempt->tuple.reserve(num_components());
                for (int i = 0; i < num_components(); ++i) {
                  const TensorShape shape =
                      ManyOutShape(i, attempt->elements_requested);
                  Tensor element;
                  attempt->context->SetStatus(attempt->context->allocate_temp(
                      component_dtypes_[i], shape, &element));
                  if (!attempt->context->status().ok()) return kComplete;
                  attempt->tuple.emplace_back(element);
                }
              }
              result = kProgress;
              Tuple tuple;
              DequeueLocked(attempt->context, &tuple);
              const int index =
                  attempt->tuple[0].dim_size(0) - attempt->elements_requested;
              for (int i = 0; i < num_components(); ++i) {
                attempt->context->SetStatus(batch_util::CopyElementToSlice(
                    std::move(tuple[i]), &attempt->tuple[i], index));
                if (!attempt->context->status().ok()) return kComplete;
              }
              tuple.clear();
              --attempt->elements_requested;
              if (attempt->elements_requested == 0) {
                tuple = attempt->tuple;
                attempt->done_callback = [callback, tuple]() {
                  callback(tuple);
                };
                return kComplete;
              }
            }
            return result;
          });
    }
  }
  if (!already_cancelled) {
    FlushUnlocked();
  } else {
    ctx->SetStatus(errors::Cancelled("Dequeue operation was cancelled"));
    callback(Tuple());
  }
}
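
// Worked example (illustrative, not part of the original file): with
// num_elements = 4 and at least 4 enqueued entries, attempt->tuple[i] is
// allocated once with leading dimension 4; each loop iteration pops the next
// entry in the queue's priority order and copies it into slice
// index = 4 - elements_requested, i.e. slices 0, 1, 2, 3, so the output batch
// is sorted by priority (as defined by the comparator in priority_queue.h).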

Status PriorityQueue::MatchesNodeDef(const NodeDef& node_def) {
  if (!MatchesNodeDefOp(node_def, "PriorityQueue").ok() &&
      !MatchesNodeDefOp(node_def, "PriorityQueueV2").ok()) {
    return errors::InvalidArgument("Expected PriorityQueue, found ",
                                   node_def.op());
  }
  TF_RETURN_IF_ERROR(MatchesNodeDefCapacity(node_def, capacity_));
  TF_RETURN_IF_ERROR(MatchesPriorityNodeDefTypes(node_def));
  TF_RETURN_IF_ERROR(MatchesPriorityNodeDefShapes(node_def));
  return OkStatus();
}

Status PriorityQueue::MatchesPriorityNodeDefTypes(
    const NodeDef& node_def) const {
  DataTypeVector requested_dtypes;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(node_def, "component_types", &requested_dtypes));
  requested_dtypes.insert(requested_dtypes.begin(), DT_INT64);
  if (requested_dtypes != component_dtypes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component types ",
                                   DataTypeSliceString(component_dtypes_),
                                   " but requested component types were ",
                                   DataTypeSliceString(requested_dtypes));
  }
  return OkStatus();
}

Status PriorityQueue::MatchesPriorityNodeDefShapes(
    const NodeDef& node_def) const {
  std::vector<TensorShape> requested_shapes;
  TF_RETURN_IF_ERROR(GetNodeAttr(node_def, "shapes", &requested_shapes));
  requested_shapes.insert(requested_shapes.begin(), TensorShape({}));
  if (requested_shapes != component_shapes_) {
    return errors::InvalidArgument("Shared queue '", name_,
                                   "' has component shapes ",
                                   ShapeListString(component_shapes_),
                                   " but requested component shapes were ",
                                   ShapeListString(requested_shapes));
  }
  return OkStatus();
}
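
// Worked example (illustrative, not part of the original file): a node whose
// attrs are component_types = [DT_FLOAT, DT_STRING] and shapes = [[2], []]
// matches a shared PriorityQueue whose stored component_dtypes_ are
// [DT_INT64, DT_FLOAT, DT_STRING] and component_shapes_ are [[], [2], []],
// because the scalar int64 priority component is prepended to the requested
// attrs before the comparison.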

}  // namespace tensorflow