1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <map>
16 
17 #include "tensorflow/core/common_runtime/function.h"
18 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
19 #include "tensorflow/core/data/captured_function.h"
20 #include "tensorflow/core/data/dataset_utils.h"
21 #include "tensorflow/core/framework/dataset.h"
22 #include "tensorflow/core/framework/partial_tensor_shape.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/kernels/data/window_dataset.h"
25 #include "tensorflow/core/lib/random/random.h"
26 
27 namespace tensorflow {
28 namespace data {
29 namespace experimental {
30 namespace {
31 
32 class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
33  public:
GroupByWindowDatasetOp(OpKernelConstruction * ctx)34   explicit GroupByWindowDatasetOp(OpKernelConstruction* ctx)
35       : UnaryDatasetOpKernel(ctx) {
36     OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, "key_func", /*params=*/{},
37                                                  &key_func_metadata_));
38     OP_REQUIRES_OK(ctx,
39                    FunctionMetadata::Create(ctx, "reduce_func", /*params=*/{},
40                                             &reduce_func_metadata_));
41     OP_REQUIRES_OK(
42         ctx, FunctionMetadata::Create(ctx, "window_size_func", /*params=*/{},
43                                       &window_size_func_metadata_));
44     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
45     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
46   }
47 
MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)48   void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
49                    DatasetBase** output) override {
50     std::unique_ptr<CapturedFunction> captured_key_func;
51     OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, key_func_metadata_,
52                                                  "key_func_other_arguments",
53                                                  &captured_key_func));
54 
55     std::unique_ptr<CapturedFunction> captured_reduce_func;
56     OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, reduce_func_metadata_,
57                                                  "reduce_func_other_arguments",
58                                                  &captured_reduce_func));
59 
60     std::unique_ptr<CapturedFunction> captured_window_size_func;
61     OP_REQUIRES_OK(ctx,
62                    CapturedFunction::Create(ctx, window_size_func_metadata_,
63                                             "window_size_func_other_arguments",
64                                             &captured_window_size_func));
65 
66     *output = new Dataset(ctx, input, std::move(captured_key_func),
67                           std::move(captured_reduce_func),
68                           std::move(captured_window_size_func), output_types_,
69                           output_shapes_);
70   }
71 
72  private:
73   class Dataset : public DatasetBase {
74    public:
Dataset(OpKernelContext * ctx,const DatasetBase * input,std::unique_ptr<CapturedFunction> captured_key_func,std::unique_ptr<CapturedFunction> captured_reduce_func,std::unique_ptr<CapturedFunction> captured_window_size_func,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes)75     Dataset(OpKernelContext* ctx, const DatasetBase* input,
76             std::unique_ptr<CapturedFunction> captured_key_func,
77             std::unique_ptr<CapturedFunction> captured_reduce_func,
78             std::unique_ptr<CapturedFunction> captured_window_size_func,
79             const DataTypeVector& output_types,
80             const std::vector<PartialTensorShape>& output_shapes)
81         : DatasetBase(DatasetContext(ctx)),
82           input_(input),
83           captured_key_func_(std::move(captured_key_func)),
84           captured_reduce_func_(std::move(captured_reduce_func)),
85           captured_window_size_func_(std::move(captured_window_size_func)),
86           output_types_(output_types),
87           output_shapes_(output_shapes) {
88       input_->Ref();
89     }
90 
~Dataset()91     ~Dataset() override { input_->Unref(); }
92 
MakeIteratorInternal(const string & prefix) const93     std::unique_ptr<IteratorBase> MakeIteratorInternal(
94         const string& prefix) const override {
95       return std::make_unique<Iterator>(
96           Iterator::Params{this, strings::StrCat(prefix, "::GroupByWindow")});
97     }
98 
output_dtypes() const99     const DataTypeVector& output_dtypes() const override {
100       return output_types_;
101     }
output_shapes() const102     const std::vector<PartialTensorShape>& output_shapes() const override {
103       return output_shapes_;
104     }
105 
DebugString() const106     string DebugString() const override {
107       return "GroupByWindowDatasetOp::Dataset";
108     }
109 
CardinalityInternal() const110     int64_t CardinalityInternal() const override {
111       int64_t n = input_->Cardinality();
112       if (n == kInfiniteCardinality) {
113         return n;
114       }
115       return kUnknownCardinality;
116     }
117 
InputDatasets(std::vector<const DatasetBase * > * inputs) const118     Status InputDatasets(
119         std::vector<const DatasetBase*>* inputs) const override {
120       inputs->push_back(input_);
121       return OkStatus();
122     }
123 
CheckExternalState() const124     Status CheckExternalState() const override {
125       TF_RETURN_IF_ERROR(captured_key_func_->CheckExternalState());
126       TF_RETURN_IF_ERROR(captured_reduce_func_->CheckExternalState());
127       TF_RETURN_IF_ERROR(captured_window_size_func_->CheckExternalState());
128       return input_->CheckExternalState();
129     }
130 
131    protected:
AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const132     Status AsGraphDefInternal(SerializationContext* ctx,
133                               DatasetGraphDefBuilder* b,
134                               Node** output) const override {
135       Node* input_graph_node = nullptr;
136       TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
137 
138       std::vector<Node*> key_func_other_arguments_node;
139       DataTypeVector key_func_other_arguments_types;
140       TF_RETURN_IF_ERROR(
141           captured_key_func_->AddToGraph(ctx, b, &key_func_other_arguments_node,
142                                          &key_func_other_arguments_types));
143 
144       std::vector<Node*> reduce_func_other_arguments_node;
145       DataTypeVector reduce_func_other_arguments_types;
146       TF_RETURN_IF_ERROR(captured_reduce_func_->AddToGraph(
147           ctx, b, &reduce_func_other_arguments_node,
148           &reduce_func_other_arguments_types));
149 
150       std::vector<Node*> window_size_func_other_arguments_node;
151       DataTypeVector window_size_func_other_arguments_types;
152       TF_RETURN_IF_ERROR(captured_window_size_func_->AddToGraph(
153           ctx, b, &window_size_func_other_arguments_node,
154           &window_size_func_other_arguments_types));
155 
156       AttrValue key_func;
157       b->BuildAttrValue(captured_key_func_->func(), &key_func);
158       AttrValue reduce_func;
159       b->BuildAttrValue(captured_reduce_func_->func(), &reduce_func);
160       AttrValue window_size_func;
161       b->BuildAttrValue(captured_window_size_func_->func(), &window_size_func);
162 
163       AttrValue key_func_other_arguments_types_attr;
164       b->BuildAttrValue(key_func_other_arguments_types,
165                         &key_func_other_arguments_types_attr);
166       AttrValue reduce_func_other_arguments_types_attr;
167       b->BuildAttrValue(reduce_func_other_arguments_types,
168                         &reduce_func_other_arguments_types_attr);
169       AttrValue window_size_func_other_arguments_types_attr;
170       b->BuildAttrValue(window_size_func_other_arguments_types,
171                         &window_size_func_other_arguments_types_attr);
172 
173       TF_RETURN_IF_ERROR(b->AddDataset(
174           this, {{0, input_graph_node}},
175           {{1, key_func_other_arguments_node},
176            {2, reduce_func_other_arguments_node},
177            {3, window_size_func_other_arguments_node}},
178           {{"key_func", key_func},
179            {"reduce_func", reduce_func},
180            {"window_size_func", window_size_func},
181            {"Tkey_func_other_arguments", key_func_other_arguments_types_attr},
182            {"Treduce_func_other_arguments",
183             reduce_func_other_arguments_types_attr},
184            {"Twindow_size_func_other_arguments",
185             window_size_func_other_arguments_types_attr}},
186           output));
187       return OkStatus();
188     }
189 
190    private:
191     class Iterator : public DatasetIterator<Dataset> {
192      public:
Iterator(const Params & params)193       explicit Iterator(const Params& params)
194           : DatasetIterator<Dataset>(params) {}
195 
Initialize(IteratorContext * ctx)196       Status Initialize(IteratorContext* ctx) override {
197         TF_RETURN_IF_ERROR(
198             dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
199         TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(
200             ctx, &instantiated_key_func_));
201         TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(
202             ctx, &instantiated_reduce_func_));
203         TF_RETURN_IF_ERROR(dataset()->captured_window_size_func_->Instantiate(
204             ctx, &instantiated_window_size_func_));
205         return OkStatus();
206       }
207 
GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)208       Status GetNextInternal(IteratorContext* ctx,
209                              std::vector<Tensor>* out_tensors,
210                              bool* end_of_sequence) override {
211         mutex_lock l(mu_);
212         do {
213           if (current_group_iterator_) {
214             // We are currently processing a group, so try to get the
215             // next element.
216             bool end_of_group;
217             TF_RETURN_IF_ERROR(current_group_iterator_->GetNext(
218                 MakeNestedIteratorContext(ctx), out_tensors, &end_of_group));
219             if (!end_of_group) {
220               // Produce the subelement as output.
221               *end_of_sequence = false;
222               return OkStatus();
223             }
224             // We have reached the end of the current group, so maybe move on
225             // to the next group.
226             current_group_iterator_.reset();
227             groups_.erase(current_key_);
228           }
229 
230           // Iterate through the input dataset until we get a full
231           // group, or reach the end.
232           while (!end_of_input_) {
233             std::vector<Tensor> next_input_element;
234             TF_RETURN_IF_ERROR(
235                 input_impl_->GetNext(MakeNestedIteratorContext(ctx),
236                                      &next_input_element, &end_of_input_));
237 
238             if (!end_of_input_) {
239               // Run the key function on the input element to identify its
240               // group.
241               std::vector<Tensor> key_func_output;
242               TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs(
243                   ctx, next_input_element, &key_func_output, model_node()));
244 
245               if (key_func_output.size() != 1 ||
246                   key_func_output[0].dtype() != DT_INT64 ||
247                   key_func_output[0].NumElements() != 1) {
248                 // TODO(b/78665031): Support non-int64 keys.
249                 return errors::InvalidArgument(
250                     "`key_func` must return a scalar int64.");
251               }
252               const int64_t key = key_func_output[0].scalar<int64_t>()();
253 
254               if (window_sizes_.find(key) == window_sizes_.end()) {
255                 // Run the window size function on the key to identify its
256                 // window size.
257                 std::vector<Tensor> window_size_func_output;
258                 TF_RETURN_IF_ERROR(instantiated_window_size_func_->Run(
259                     ctx, std::move(key_func_output), &window_size_func_output,
260                     model_node()));
261 
262                 if (window_size_func_output.size() != 1 ||
263                     window_size_func_output[0].dtype() != DT_INT64 ||
264                     window_size_func_output[0].NumElements() != 1) {
265                   // TODO(mrry): Support non-int64 window sizes.
266                   return errors::InvalidArgument(
267                       "`window_size_func` must return a scalar int64.");
268                 }
269                 const int64_t window_size =
270                     window_size_func_output[0].scalar<int64_t>()();
271                 if (window_size <= 0) {
272                   return errors::InvalidArgument(
273                       "Window size must be greater than zero, but got ",
274                       window_size, ".");
275                 }
276                 window_sizes_[key] = window_size;
277               }
278 
279               const int64_t window_size = window_sizes_[key];
280 
281               std::vector<std::vector<Tensor>>& group = groups_[key];
282               group.push_back(std::move(next_input_element));
283 
284               if (group.size() == window_size) {
285                 current_key_ = key;
286                 TF_RETURN_IF_ERROR(StartFlushingGroup(ctx, key));
287                 break;
288               }
289             }
290           }
291 
292           if (end_of_input_) {
293             if (!groups_.empty()) {
294               // We have consumed all of the input, so flush an
295               // arbitrarily chosen group.
296               current_key_ = groups_.begin()->first;
297               TF_RETURN_IF_ERROR(
298                   StartFlushingGroup(ctx, groups_.begin()->first));
299             }
300           }
301         } while (current_group_iterator_ || !end_of_input_);
302 
303         *end_of_sequence = true;
304         return OkStatus();
305       }
306 
307      protected:
CreateNode(IteratorContext * ctx,model::Node::Args args) const308       std::shared_ptr<model::Node> CreateNode(
309           IteratorContext* ctx, model::Node::Args args) const override {
310         return model::MakeUnknownRatioNode(std::move(args));
311       }
312 
SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)313       Status SaveInternal(SerializationContext* ctx,
314                           IteratorStateWriter* writer) override {
315         TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
316             dataset()->captured_key_func_->CheckExternalState()));
317         TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
318             dataset()->captured_reduce_func_->CheckExternalState()));
319         TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
320             dataset()->captured_window_size_func_->CheckExternalState()));
321         mutex_lock l(mu_);
322         TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
323 
324         if (end_of_input_) {
325           TF_RETURN_IF_ERROR(
326               writer->WriteScalar(full_name("end_of_input"), ""));
327         }
328 
329         // Saving groups_
330         if (!groups_.empty()) {
331           TF_RETURN_IF_ERROR(
332               writer->WriteScalar(full_name("groups_size"), groups_.size()));
333           int idx = 0;
334           for (auto it = groups_.begin(); it != groups_.end(); it++) {
335             int64_t key = it->first;
336             TF_RETURN_IF_ERROR(writer->WriteScalar(
337                 full_name(strings::StrCat("groups_[", idx, "]->key")), key));
338             TF_RETURN_IF_ERROR(SaveGroup(
339                 writer, full_name(strings::StrCat("groups_[", idx, "]")),
340                 it->second));
341             idx++;
342           }
343         }
344 
345         // Saving window_sizes_
346         if (!window_sizes_.empty()) {
347           TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("window_sizes_size"),
348                                                  window_sizes_.size()));
349           int idx = 0;
350           for (auto it = window_sizes_.begin(); it != window_sizes_.end();
351                it++) {
352             TF_RETURN_IF_ERROR(writer->WriteScalar(
353                 full_name(strings::StrCat("window_sizes_[", idx, "]->key")),
354                 it->first));
355             TF_RETURN_IF_ERROR(writer->WriteScalar(
356                 full_name(strings::StrCat("window_sizes_[", idx, "]->value")),
357                 it->second));
358             idx++;
359           }
360         }
361 
362         if (current_group_iterator_) {
363           TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_group_iterator_));
364 
365           // Saving current_key_
366           TF_RETURN_IF_ERROR(
367               writer->WriteScalar(full_name("current_key"), current_key_));
368         } else {
369           TF_RETURN_IF_ERROR(writer->WriteScalar(
370               full_name("current_iterator_not_initialized"), ""));
371         }
372         TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("group_counter"),
373                                                group_counter_ - 1));
374         return OkStatus();
375       }
376 
RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)377       Status RestoreInternal(IteratorContext* ctx,
378                              IteratorStateReader* reader) override {
379         mutex_lock l(mu_);
380         TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
381 
382         if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
383 
384         // Restoring groups_
385         if (reader->Contains(full_name("groups_size"))) {
386           int64_t size;
387           TF_RETURN_IF_ERROR(
388               reader->ReadScalar(full_name("groups_size"), &size));
389           for (int idx = 0; idx < size; idx++) {
390             int64_t key;
391             TF_RETURN_IF_ERROR(reader->ReadScalar(
392                 full_name(strings::StrCat("groups_[", idx, "]->key")), &key));
393             std::vector<std::vector<Tensor>> group;
394             TF_RETURN_IF_ERROR(RestoreGroup(
395                 ctx, reader, full_name(strings::StrCat("groups_[", idx, "]")),
396                 &group));
397             groups_[key] = group;
398           }
399         }
400 
401         // Restoring window_sizes_
402         if (reader->Contains(full_name("window_sizes_size"))) {
403           int64_t size;
404           TF_RETURN_IF_ERROR(
405               reader->ReadScalar(full_name("window_sizes_size"), &size));
406           for (int idx = 0; idx < size; idx++) {
407             int64_t key;
408             TF_RETURN_IF_ERROR(reader->ReadScalar(
409                 full_name(strings::StrCat("window_sizes_[", idx, "]->key")),
410                 &key));
411             TF_RETURN_IF_ERROR(reader->ReadScalar(
412                 full_name(strings::StrCat("window_sizes_[", idx, "]->value")),
413                 &window_sizes_[key]));
414           }
415         }
416 
417         // Group counter needs to be restored before current group iterator.
418         TF_RETURN_IF_ERROR(
419             reader->ReadScalar(full_name("group_counter"), &group_counter_));
420 
421         if (reader->Contains(full_name("current_iterator_not_initialized"))) {
422           current_group_iterator_.reset();
423         } else {
424           // Restore current_key_
425           TF_RETURN_IF_ERROR(
426               reader->ReadScalar(full_name("current_key"), &current_key_));
427 
428           // Initialize current_group_iterator_
429           TF_RETURN_IF_ERROR(StartFlushingGroup(ctx, current_key_));
430           // Restore current_group_iterator_ state
431           TF_RETURN_IF_ERROR(
432               RestoreInput(ctx, reader, current_group_iterator_));
433         }
434         return OkStatus();
435       }
436 
437      private:
SaveGroup(IteratorStateWriter * writer,const string & name,const std::vector<std::vector<Tensor>> & group)438       Status SaveGroup(IteratorStateWriter* writer, const string& name,
439                        const std::vector<std::vector<Tensor>>& group)
440           TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
441         TF_RETURN_IF_ERROR(
442             writer->WriteScalar(strings::StrCat(name, "_size"), group.size()));
443         for (int i = 0; i < group.size(); i++) {
444           TF_RETURN_IF_ERROR(writer->WriteScalar(
445               strings::StrCat(name, "[", i, "]_size"), group[i].size()));
446           for (int j = 0; j < group[i].size(); j++) {
447             TF_RETURN_IF_ERROR(writer->WriteTensor(
448                 strings::StrCat(name, "[", i, "][", j, "]"), group[i][j]));
449           }
450         }
451         return OkStatus();
452       }
453 
RestoreGroup(IteratorContext * ctx,IteratorStateReader * reader,const string & name,std::vector<std::vector<Tensor>> * group)454       Status RestoreGroup(IteratorContext* ctx, IteratorStateReader* reader,
455                           const string& name,
456                           std::vector<std::vector<Tensor>>* group)
457           TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
458         int64_t group_size;
459         TF_RETURN_IF_ERROR(
460             reader->ReadScalar(strings::StrCat(name, "_size"), &group_size));
461         group->resize(group_size);
462         for (int i = 0; i < group_size; i++) {
463           int64_t vector_size;
464           TF_RETURN_IF_ERROR(reader->ReadScalar(
465               strings::StrCat(name, "[", i, "]_size"), &vector_size));
466           group->at(i).resize(vector_size);
467           for (int j = 0; j < vector_size; j++) {
468             TF_RETURN_IF_ERROR(reader->ReadTensor(
469                 ctx->flr(), strings::StrCat(name, "[", i, "][", j, "]"),
470                 &group->at(i)[j]));
471           }
472         }
473         return OkStatus();
474       }
475 
StartFlushingGroup(IteratorContext * ctx,int64_t key)476       Status StartFlushingGroup(IteratorContext* ctx, int64_t key)
477           TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
478         DatasetBase* group_dataset;
479         TF_RETURN_IF_ERROR(
480             NewWindow(groups_[key], dataset()->input_->output_dtypes(),
481                       dataset()->input_->output_shapes(), &group_dataset));
482 
483         Tensor key_arg(DT_INT64, TensorShape({}));
484         key_arg.scalar<int64_t>()() = key;
485 
486         Tensor group_dataset_arg(DT_VARIANT, TensorShape({}));
487         TF_RETURN_IF_ERROR(
488             StoreDatasetInVariantTensor(group_dataset, &group_dataset_arg));
489 
490         std::vector<Tensor> args(
491             {std::move(key_arg), std::move(group_dataset_arg)});
492         std::vector<Tensor> return_values;
493         // If not restoring, pass the model node of this iterator in order to
494         // exclude captured function run time from being added to the processing
495         // time of the node. If restoring, pass nullptr to not record processing
496         // time because iterator modeling is only used to model Iterator's
497         // GetNext() resource usage.
498         TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run(
499             ctx, std::move(args), &return_values,
500             ctx->is_restoring() ? nullptr : model_node()));
501 
502         if (!(return_values.size() == 1 &&
503               return_values[0].dtype() == DT_VARIANT &&
504               TensorShapeUtils::IsScalar(return_values[0].shape()))) {
505           return errors::InvalidArgument(
506               "`reduce_func` must return a single scalar of dtype "
507               "DT_VARIANT.");
508         }
509 
510         // Retrieve the dataset that was created in `f`.
511         // `returned_dataset` is borrowed from the `return_values[0]`.
512         DatasetBase* returned_dataset;
513         TF_RETURN_IF_ERROR(
514             GetDatasetFromVariantTensor(return_values[0], &returned_dataset));
515 
516         // Create an iterator for the dataset that was returned by `f`.
517         return returned_dataset->MakeIterator(
518             MakeNestedIteratorContext(ctx), this,
519             strings::StrCat(prefix(), "[", group_counter_++, "]"),
520             &current_group_iterator_);
521       }
522 
523       mutex mu_;
524       int64_t group_counter_ TF_GUARDED_BY(mu_) = 0;
525       std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
526       // TODO(mrry): Optimize for dense key space if appropriate.
527       bool end_of_input_ TF_GUARDED_BY(mu_) = false;
528       int64_t current_key_ TF_GUARDED_BY(mu_);
529       std::map<int64_t, std::vector<std::vector<Tensor>>> groups_
530           TF_GUARDED_BY(mu_);
531       std::unique_ptr<IteratorBase> current_group_iterator_ TF_GUARDED_BY(mu_);
532       std::map<int64_t, int64_t> window_sizes_ TF_GUARDED_BY(mu_);
533       std::unique_ptr<InstantiatedCapturedFunction> instantiated_key_func_;
534       std::unique_ptr<InstantiatedCapturedFunction> instantiated_reduce_func_;
535       std::unique_ptr<InstantiatedCapturedFunction>
536           instantiated_window_size_func_;
537     };
538 
539     const DatasetBase* const input_;
540     const std::unique_ptr<CapturedFunction> captured_key_func_;
541     const std::unique_ptr<CapturedFunction> captured_reduce_func_;
542     const std::unique_ptr<CapturedFunction> captured_window_size_func_;
543     const DataTypeVector output_types_;
544     const std::vector<PartialTensorShape> output_shapes_;
545   };
546 
547   std::shared_ptr<FunctionMetadata> key_func_metadata_ = nullptr;
548   std::shared_ptr<FunctionMetadata> reduce_func_metadata_ = nullptr;
549   std::shared_ptr<FunctionMetadata> window_size_func_metadata_ = nullptr;
550   DataTypeVector output_types_;
551   std::vector<PartialTensorShape> output_shapes_;
552 };
553 
554 REGISTER_KERNEL_BUILDER(Name("GroupByWindowDataset").Device(DEVICE_CPU),
555                         GroupByWindowDatasetOp);
556 REGISTER_KERNEL_BUILDER(
557     Name("ExperimentalGroupByWindowDataset").Device(DEVICE_CPU),
558     GroupByWindowDatasetOp);
559 
560 REGISTER_INPUT_COLOCATION_EXEMPTION("GroupByWindowDataset");
561 REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalGroupByWindowDataset");
562 
563 }  // namespace
564 }  // namespace experimental
565 }  // namespace data
566 }  // namespace tensorflow
567