/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstddef>
#include <functional>
#include <map>
#include <mutex>
#include <numeric>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace {

// Partial Ordering Comparator for Tensor keys containing scalar int64's
struct KeyTensorLess {
  bool operator()(const Tensor& lhs, const Tensor& rhs) const {
    return std::less<int64_t>{}(lhs.scalar<int64_t>()(),
                                rhs.scalar<int64_t>()());
  }
};

// Key Equality operator for Tensor keys containing scalar int64's
struct KeyTensorEqual {
  bool operator()(const Tensor& lhs, const Tensor& rhs) const {
    return std::equal_to<int64_t>{}(lhs.scalar<int64_t>()(),
                                    rhs.scalar<int64_t>()());
  }
};

// Hash for Tensor keys containing scalar int64's
struct KeyTensorHash {
  std::size_t operator()(const Tensor& key) const {
    return std::hash<int64_t>{}(key.scalar<int64_t>()());
  }
};

// Primary template.
template <bool Ordered, typename Data>
struct MapTraits;

// Partial specialization for ordered.
template <typename Data>
struct MapTraits<true, Data> {
  using KeyType = Tensor;
  using DataType = Data;
  using MapType = std::map<KeyType, Data, KeyTensorLess>;
};

// Partial specialization for unordered.
template <typename Data>
struct MapTraits<false, Data> {
  using KeyType = Tensor;
  using DataType = Data;
  using MapType =
      std::unordered_map<KeyType, Data, KeyTensorHash, KeyTensorEqual>;
};
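
// For example, MapTraits<true, Data>::MapType is
// std::map<Tensor, Data, KeyTensorLess> and MapTraits<false, Data>::MapType
// is std::unordered_map<Tensor, Data, KeyTensorHash, KeyTensorEqual>; both
// key on scalar int64 Tensors.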

// Wrapper around map/unordered_map.
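// Complete tuples live in map_ (ordered by key when Ordered is true) and
// partially staged tuples live in incomplete_; an incomplete tuple is
// promoted to map_ once a value has been supplied for every index in
// dtypes_. A capacity_ or memory_limit_ of zero means unbounded. The
// condition variable full_ wakes blocked inserters when space is freed and
// not_empty_ wakes blocked removers when data becomes available.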
template <bool Ordered>
class StagingMap : public ResourceBase {
 public:
  // Public typedefs
  using Tuple = std::vector<Tensor>;
  using OptionalTensor = gtl::optional<Tensor>;
  using OptionalTuple = std::vector<OptionalTensor>;

  using MapType = typename MapTraits<Ordered, OptionalTuple>::MapType;
  using KeyType = typename MapTraits<Ordered, OptionalTuple>::KeyType;

  using IncompleteType = typename MapTraits<false, OptionalTuple>::MapType;

 private:
  // Private variables
  DataTypeVector dtypes_ TF_GUARDED_BY(mu_);
  std::size_t capacity_ TF_GUARDED_BY(mu_);
  std::size_t memory_limit_ TF_GUARDED_BY(mu_);
  std::size_t current_bytes_ TF_GUARDED_BY(mu_);
  tensorflow::mutex mu_;
  tensorflow::condition_variable not_empty_;
  tensorflow::condition_variable full_;
  IncompleteType incomplete_ TF_GUARDED_BY(mu_);
  MapType map_ TF_GUARDED_BY(mu_);

 private:
  // private methods

  // If map is configured for bounded capacity, notify
  // waiting inserters that space is now available
  void notify_inserters_if_bounded() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (has_capacity() || has_memory_limit()) {
      // Notify all inserters. The removal of an element
      // may make memory available for many inserters
      // to insert new elements
      full_.notify_all();
    }
  }

  // Notify all removers waiting to extract values
  // that data is now available
  void notify_removers() {
    // Notify all removers. This is because they are
    // waiting for specific keys to appear in the map
    // so we don't know which one to wake up.
    not_empty_.notify_all();
  }

  bool has_capacity() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return capacity_ > 0;
  }

  bool has_memory_limit() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return memory_limit_ > 0;
  }

  bool would_exceed_memory_limit(std::size_t bytes) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return has_memory_limit() && bytes + current_bytes_ > memory_limit_;
  }

  bool is_capacity_full() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return has_capacity() && map_.size() >= capacity_;
  }

  // Get number of bytes in the tuple
  std::size_t get_tuple_bytes(const Tuple& tuple) {
    return std::accumulate(tuple.begin(), tuple.end(),
                           static_cast<std::size_t>(0),
                           [](const std::size_t& lhs, const Tensor& rhs) {
                             return lhs + rhs.TotalBytes();
                           });
  }

  // Get number of bytes in the incomplete tuple
  // (slots without a value contribute zero bytes)
  std::size_t get_tuple_bytes(const OptionalTuple& tuple) {
    return std::accumulate(
        tuple.begin(), tuple.end(), static_cast<std::size_t>(0),
        [](const std::size_t& lhs, const OptionalTensor& rhs) {
          return lhs + (rhs.has_value() ? rhs.value().TotalBytes() : 0);
        });
  }

  // Check that the index is within bounds
  Status check_index(const Tensor& key, std::size_t index)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (index >= dtypes_.size()) {
      return Status(errors::InvalidArgument(
          "Index '", index, "' for key '", key.scalar<int64_t>()(),
          "' was out of bounds '", dtypes_.size(), "'."));
    }

    return OkStatus();
  }

  Status copy_or_move_tensors(OptionalTuple* map_tuple, const Tensor& key,
                              const Tensor& indices, Tuple* output,
                              bool copy = false)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto findices = indices.flat<int>();

    // Return values at specified indices
    for (std::size_t i = 0; i < findices.dimension(0); ++i) {
      std::size_t index = findices(i);

      TF_RETURN_IF_ERROR(check_index(key, index));

      // Insist on a value present at the specified index
      if (!(*map_tuple)[index].has_value()) {
        return Status(errors::InvalidArgument(
            "Tensor at index '", index, "' for key '", key.scalar<int64_t>()(),
            "' has already been removed."));
      }

      // Copy the contained tensor into the output tuple
      output->push_back((*map_tuple)[index].value());

      // Clear out the entry if we're not copying (moving)
      if (!copy) {
        (*map_tuple)[index].reset();
      }
    }

    return OkStatus();
  }

  // Check that the optional value at the specified index
  // is uninitialized
  Status check_index_uninitialized(const Tensor& key, std::size_t index,
                                   const OptionalTuple& tuple)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (tuple[index].has_value()) {
      return errors::InvalidArgument("The tensor for index '", index,
                                     "' for key '", key.scalar<int64_t>()(),
                                     "' was already initialized '",
                                     dtypes_.size(), "'.");
    }

    return OkStatus();
  }

  // Check that the indices are strictly ordered
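  // and not empty. For example, with four dtypes an indices tensor of
  // [0, 2, 3] is accepted, while [2, 0] or [1, 1] is rejected.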
  Status check_index_ordering(const Tensor& indices) {
    if (indices.NumElements() == 0) {
      return errors::InvalidArgument("Indices are empty");
    }

    auto findices = indices.flat<int>();

    for (std::size_t i = 0; i < findices.dimension(0) - 1; ++i) {
      if (findices(i) < findices(i + 1)) {
        continue;
      }

      return errors::InvalidArgument("Indices are not strictly ordered");
    }

    return OkStatus();
  }

  // Check that the given bytes fit within the memory limit
  Status check_memory_limit(std::size_t bytes)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (has_memory_limit() && bytes > memory_limit_) {
      return errors::ResourceExhausted(
          "Attempted to insert tensors with combined size of '", bytes,
          "' bytes into Staging Area with a memory limit of '", memory_limit_,
          "'.");
    }

    return OkStatus();
  }

  // Insert incomplete data into the Barrier
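  // For example, with three dtypes a put() with indices [0, 2] creates an
  // incomplete entry; a later put() for the same key with indices [1] fills
  // the remaining slot and the completed tuple is moved into the main map.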
  Status put_incomplete(const KeyType& key, const Tensor& indices,
                        OptionalTuple* tuple, tensorflow::mutex_lock* lock)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto findices = indices.flat<int>();

    // Search for the key in our incomplete set
    auto it = incomplete_.find(key);

    // Check that the tuple fits within the memory limit
    std::size_t tuple_bytes = get_tuple_bytes(*tuple);
    TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));

    // Wait until we don't exceed the memory limit
    while (would_exceed_memory_limit(tuple_bytes)) {
      full_.wait(*lock);
    }

    // This key isn't present in the incomplete set
    // Create OptionalTuple and insert
    if (it == incomplete_.end()) {
      OptionalTuple empty(dtypes_.size());

      // Initialize empty tuple with given data
      for (std::size_t i = 0; i < findices.dimension(0); ++i) {
        std::size_t index = findices(i);
        TF_RETURN_IF_ERROR(check_index(key, index));

        // Assign tuple at this index
        empty[index] = std::move((*tuple)[i]);
      }

      // Insert into incomplete map
      incomplete_.insert({key, std::move(empty)});

      // Increment size
      current_bytes_ += tuple_bytes;
    }
    // Found an entry in the incomplete index
    // Update with given data and insert complete entries
    // into the main map
    else {
      // Reference existing incomplete tuple
      OptionalTuple& present = it->second;

      // Assign given data
      for (std::size_t i = 0; i < findices.dimension(0); ++i) {
        std::size_t index = findices(i);
        TF_RETURN_IF_ERROR(check_index(key, index));
        TF_RETURN_IF_ERROR(check_index_uninitialized(key, index, present));

        // Assign tuple at this index
        present[index] = std::move((*tuple)[i]);
      }

      // Increment size
      current_bytes_ += tuple_bytes;

      // Do we have values at all tuple elements?
      bool complete =
          std::all_of(present.begin(), present.end(),
                      [](const OptionalTensor& v) { return v.has_value(); });

      // If so, put the tuple in the actual map
      if (complete) {
        OptionalTuple insert_tuple = std::move(it->second);

        // Remove from incomplete
        incomplete_.erase(it);

        TF_RETURN_IF_ERROR(put_complete(key, &insert_tuple));
      }
    }

    return OkStatus();
  }

  // Does the insertion into the actual staging area
  Status put_complete(const KeyType& key, OptionalTuple* tuple)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // Insert key and tuples into the map
    map_.insert({key, std::move(*tuple)});

    notify_removers();

    return OkStatus();
  }

 public:
  // public methods
  explicit StagingMap(const DataTypeVector& dtypes, std::size_t capacity,
                      std::size_t memory_limit)
      : dtypes_(dtypes),
        capacity_(capacity),
        memory_limit_(memory_limit),
        current_bytes_(0) {}

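  // Stages `tuple` under `key`. If `indices` does not cover every slot in
  // dtypes_, the values are routed to the incomplete map via
  // put_incomplete(); otherwise the call blocks while the staging area is at
  // capacity or over its memory limit and then inserts the complete tuple.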
  Status put(KeyType* key, const Tensor* indices, OptionalTuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    // Handle incomplete inserts
    if (indices->NumElements() != dtypes_.size()) {
      return put_incomplete(*key, *indices, tuple, &lock);
    }

    std::size_t tuple_bytes = get_tuple_bytes(*tuple);
    // Check that tuple_bytes fits within the memory limit
    TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));

    // Wait until there's space for insertion.
    while (would_exceed_memory_limit(tuple_bytes) || is_capacity_full()) {
      full_.wait(lock);
    }

    // Do the put operation
    TF_RETURN_IF_ERROR(put_complete(*key, tuple));

    // Update the current size
    current_bytes_ += tuple_bytes;

    return OkStatus();
  }

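  // Copies (rather than moves) the tensors at `indices` for `key` into
  // `tuple`, blocking until the key is present; the staged entry is left in
  // place. This is the path used by MapPeekOp below.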
  Status get(const KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    typename MapType::iterator it;

    // Wait until the element with the requested key is present
    while ((it = map_.find(*key)) == map_.end()) {
      not_empty_.wait(lock);
    }

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple, true));

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    return OkStatus();
  }

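  // Like get(), but moves the requested tensors out of the staged tuple,
  // blocking until the key is present. Once every slot has been consumed the
  // entry is erased and any blocked inserters are notified.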
  Status pop(const KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    typename MapType::iterator it;

    // Wait until the element with the requested key is present
    while ((it = map_.find(*key)) == map_.end()) {
      not_empty_.wait(lock);
    }

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple));

    // Remove entry if all the values have been consumed
    if (!std::any_of(
            it->second.begin(), it->second.end(),
            [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
      map_.erase(it);
    }

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    notify_inserters_if_bounded();

    return OkStatus();
  }

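  // Moves the tensors at `indices` out of the first element of the map and
  // returns its key through `key`: the smallest key for the ordered map, an
  // arbitrary element for the unordered one. Blocks while the map is empty.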
  Status popitem(KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    // Wait until map is not empty
    while (this->map_.empty()) {
      not_empty_.wait(lock);
    }

    // Move from the first element and erase it

    auto it = map_.begin();

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple));

    *key = it->first;

    // Remove entry if all the values have been consumed
    if (!std::any_of(
            it->second.begin(), it->second.end(),
            [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
      map_.erase(it);
    }

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    notify_inserters_if_bounded();

    return OkStatus();
  }

  Status clear() {
    tensorflow::mutex_lock lock(mu_);
    map_.clear();
    incomplete_.clear();
    current_bytes_ = 0;

    notify_inserters_if_bounded();

    return OkStatus();
  }

  std::size_t incomplete_size() {
    tensorflow::mutex_lock lock(mu_);
    return incomplete_.size();
  }

  std::size_t size() {
    tensorflow::mutex_lock lock(mu_);
    return map_.size();
  }

  string DebugString() const override { return "StagingMap"; }
};

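// Looks up the StagingMap resource identified by `ndef` in the resource
// manager from the op kernel context, creating it on first use from the
// node's "dtypes", "capacity" and "memory_limit" attributes. The caller
// receives a reference and is expected to Unref() it; the kernels below do
// so via core::ScopedUnref.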
template <bool Ordered>
Status GetStagingMap(OpKernelContext* ctx, const NodeDef& ndef,
                     StagingMap<Ordered>** map) {
  auto rm = ctx->resource_manager();
  ContainerInfo cinfo;

  // Lambda for creating the Staging Area
  auto create_fn = [&ndef](StagingMap<Ordered>** ret) -> Status {
    DataTypeVector dtypes;
    int64_t capacity;
    int64_t memory_limit;
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "dtypes", &dtypes));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit));
    *ret = new StagingMap<Ordered>(dtypes, capacity, memory_limit);
    return OkStatus();
  };

  TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */));
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<StagingMap<Ordered>>(
      cinfo.container(), cinfo.name(), map, create_fn));
  return OkStatus();
}

template <bool Ordered>
class MapStageOp : public OpKernel {
 public:
  explicit MapStageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::OptionalTuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;
    OpInputList values_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, ctx->input_list("values", &values_tensor));
    OP_REQUIRES(ctx, key_tensor->NumElements() > 0,
                errors::InvalidArgument("key must not be empty"));

    OP_REQUIRES(ctx, key_tensor->NumElements() == 1,
                errors::InvalidArgument(
                    "key must be an int64 scalar, got tensor with shape: ",
                    key_tensor->shape()));

    // Create copy for insertion into Staging Area
    Tensor key(*key_tensor);

    // Create the tuple to store
    for (std::size_t i = 0; i < values_tensor.size(); ++i) {
      tuple.push_back(values_tensor[i]);
    }

    // Store the tuple in the map
    OP_REQUIRES_OK(ctx, map->put(&key, indices_tensor, &tuple));
  }
};

REGISTER_KERNEL_BUILDER(Name("MapStage").Device(DEVICE_CPU), MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage").Device(DEVICE_CPU),
                        MapStageOp<true>);

REGISTER_KERNEL_BUILDER(Name("MapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapStageOp<true>);
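
// The DEVICE_DEFAULT registrations above pin the "key" and "indices" inputs
// to host memory, since StagingMap reads them as host-side scalars and flat
// vectors (key.scalar<int64_t>() and indices.flat<int>()); the unstage and
// peek registrations below follow the same pattern.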

template <bool Ordered>
class MapUnstageOp : public OpKernel {
 public:
  explicit MapUnstageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->pop(key_tensor, indices_tensor, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<true>);

REGISTER_KERNEL_BUILDER(Name("MapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapUnstageOp<true>);

template <bool Ordered>
class MapPeekOp : public OpKernel {
 public:
  explicit MapPeekOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->get(key_tensor, indices_tensor, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapPeek").Device(DEVICE_CPU), MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek").Device(DEVICE_CPU),
                        MapPeekOp<true>);

REGISTER_KERNEL_BUILDER(
    Name("MapPeek").HostMemory("key").HostMemory("indices").Device(
        DEVICE_DEFAULT),
    MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapPeekOp<true>);

template <bool Ordered>
class MapUnstageNoKeyOp : public OpKernel {
 public:
  explicit MapUnstageNoKeyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Pop a random (key, value) off the map
    typename StagingMap<Ordered>::KeyType key;
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->popitem(&key, indices_tensor, &tuple));

    // Set the key as the first output
    ctx->set_output(0, key);

    // Set the rest of the outputs to the tuple Tensors
    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i + 1, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<true>);

REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_DEFAULT),
                        MapUnstageNoKeyOp<true>);

template <bool Ordered>
class MapSizeOp : public OpKernel {
 public:
  explicit MapSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(map->size());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_CPU), MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapSize").Device(DEVICE_CPU),
                        MapSizeOp<true>);

REGISTER_KERNEL_BUILDER(
    Name("MapSize").Device(DEVICE_DEFAULT).HostMemory("size"),
    MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapSize").Device(DEVICE_DEFAULT).HostMemory("size"),
    MapSizeOp<true>);

template <bool Ordered>
class MapIncompleteSizeOp : public OpKernel {
 public:
  explicit MapIncompleteSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(map->incomplete_size());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<true>);

REGISTER_KERNEL_BUILDER(
    Name("MapIncompleteSize").Device(DEVICE_DEFAULT).HostMemory("size"),
    MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapIncompleteSize").Device(DEVICE_DEFAULT).HostMemory("size"),
    MapIncompleteSizeOp<true>);

template <bool Ordered>
class MapClearOp : public OpKernel {
 public:
  explicit MapClearOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    OP_REQUIRES_OK(ctx, map->clear());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_CPU), MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_CPU),
                        MapClearOp<true>);

REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_DEFAULT),
                        MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_DEFAULT),
                        MapClearOp<true>);

}  // namespace
}  // namespace tensorflow