1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include <map> 16 17 #include "tensorflow/core/common_runtime/function.h" 18 #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" 19 #include "tensorflow/core/data/captured_function.h" 20 #include "tensorflow/core/data/dataset_utils.h" 21 #include "tensorflow/core/framework/dataset.h" 22 #include "tensorflow/core/framework/partial_tensor_shape.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/kernels/data/window_dataset.h" 25 #include "tensorflow/core/lib/random/random.h" 26 27 namespace tensorflow { 28 namespace data { 29 namespace experimental { 30 namespace { 31 32 class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { 33 public: GroupByWindowDatasetOp(OpKernelConstruction * ctx)34 explicit GroupByWindowDatasetOp(OpKernelConstruction* ctx) 35 : UnaryDatasetOpKernel(ctx) { 36 OP_REQUIRES_OK(ctx, FunctionMetadata::Create(ctx, "key_func", /*params=*/{}, 37 &key_func_metadata_)); 38 OP_REQUIRES_OK(ctx, 39 FunctionMetadata::Create(ctx, "reduce_func", /*params=*/{}, 40 &reduce_func_metadata_)); 41 OP_REQUIRES_OK( 42 ctx, FunctionMetadata::Create(ctx, "window_size_func", /*params=*/{}, 43 &window_size_func_metadata_)); 44 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); 45 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 46 } 47 MakeDataset(OpKernelContext * ctx,DatasetBase * input,DatasetBase ** output)48 void MakeDataset(OpKernelContext* ctx, DatasetBase* input, 49 DatasetBase** output) override { 50 std::unique_ptr<CapturedFunction> captured_key_func; 51 OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, key_func_metadata_, 52 "key_func_other_arguments", 53 &captured_key_func)); 54 55 std::unique_ptr<CapturedFunction> captured_reduce_func; 56 OP_REQUIRES_OK(ctx, CapturedFunction::Create(ctx, reduce_func_metadata_, 57 "reduce_func_other_arguments", 58 &captured_reduce_func)); 59 60 std::unique_ptr<CapturedFunction> captured_window_size_func; 61 OP_REQUIRES_OK(ctx, 62 CapturedFunction::Create(ctx, window_size_func_metadata_, 63 "window_size_func_other_arguments", 64 &captured_window_size_func)); 65 66 *output = new Dataset(ctx, input, std::move(captured_key_func), 67 std::move(captured_reduce_func), 68 std::move(captured_window_size_func), output_types_, 69 output_shapes_); 70 } 71 72 private: 73 class Dataset : public DatasetBase { 74 public: Dataset(OpKernelContext * ctx,const DatasetBase * input,std::unique_ptr<CapturedFunction> captured_key_func,std::unique_ptr<CapturedFunction> captured_reduce_func,std::unique_ptr<CapturedFunction> captured_window_size_func,const DataTypeVector & output_types,const std::vector<PartialTensorShape> & output_shapes)75 Dataset(OpKernelContext* ctx, const DatasetBase* input, 76 std::unique_ptr<CapturedFunction> captured_key_func, 77 std::unique_ptr<CapturedFunction> captured_reduce_func, 78 std::unique_ptr<CapturedFunction> captured_window_size_func, 79 const DataTypeVector& output_types, 80 const std::vector<PartialTensorShape>& output_shapes) 81 : DatasetBase(DatasetContext(ctx)), 82 input_(input), 83 captured_key_func_(std::move(captured_key_func)), 84 captured_reduce_func_(std::move(captured_reduce_func)), 85 captured_window_size_func_(std::move(captured_window_size_func)), 86 output_types_(output_types), 87 output_shapes_(output_shapes) { 88 input_->Ref(); 89 } 90 ~Dataset()91 ~Dataset() override { input_->Unref(); } 92 MakeIteratorInternal(const string & prefix) const93 std::unique_ptr<IteratorBase> MakeIteratorInternal( 94 const string& prefix) const override { 95 return std::make_unique<Iterator>( 96 Iterator::Params{this, strings::StrCat(prefix, "::GroupByWindow")}); 97 } 98 output_dtypes() const99 const DataTypeVector& output_dtypes() const override { 100 return output_types_; 101 } output_shapes() const102 const std::vector<PartialTensorShape>& output_shapes() const override { 103 return output_shapes_; 104 } 105 DebugString() const106 string DebugString() const override { 107 return "GroupByWindowDatasetOp::Dataset"; 108 } 109 CardinalityInternal() const110 int64_t CardinalityInternal() const override { 111 int64_t n = input_->Cardinality(); 112 if (n == kInfiniteCardinality) { 113 return n; 114 } 115 return kUnknownCardinality; 116 } 117 InputDatasets(std::vector<const DatasetBase * > * inputs) const118 Status InputDatasets( 119 std::vector<const DatasetBase*>* inputs) const override { 120 inputs->push_back(input_); 121 return OkStatus(); 122 } 123 CheckExternalState() const124 Status CheckExternalState() const override { 125 TF_RETURN_IF_ERROR(captured_key_func_->CheckExternalState()); 126 TF_RETURN_IF_ERROR(captured_reduce_func_->CheckExternalState()); 127 TF_RETURN_IF_ERROR(captured_window_size_func_->CheckExternalState()); 128 return input_->CheckExternalState(); 129 } 130 131 protected: AsGraphDefInternal(SerializationContext * ctx,DatasetGraphDefBuilder * b,Node ** output) const132 Status AsGraphDefInternal(SerializationContext* ctx, 133 DatasetGraphDefBuilder* b, 134 Node** output) const override { 135 Node* input_graph_node = nullptr; 136 TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); 137 138 std::vector<Node*> key_func_other_arguments_node; 139 DataTypeVector key_func_other_arguments_types; 140 TF_RETURN_IF_ERROR( 141 captured_key_func_->AddToGraph(ctx, b, &key_func_other_arguments_node, 142 &key_func_other_arguments_types)); 143 144 std::vector<Node*> reduce_func_other_arguments_node; 145 DataTypeVector reduce_func_other_arguments_types; 146 TF_RETURN_IF_ERROR(captured_reduce_func_->AddToGraph( 147 ctx, b, &reduce_func_other_arguments_node, 148 &reduce_func_other_arguments_types)); 149 150 std::vector<Node*> window_size_func_other_arguments_node; 151 DataTypeVector window_size_func_other_arguments_types; 152 TF_RETURN_IF_ERROR(captured_window_size_func_->AddToGraph( 153 ctx, b, &window_size_func_other_arguments_node, 154 &window_size_func_other_arguments_types)); 155 156 AttrValue key_func; 157 b->BuildAttrValue(captured_key_func_->func(), &key_func); 158 AttrValue reduce_func; 159 b->BuildAttrValue(captured_reduce_func_->func(), &reduce_func); 160 AttrValue window_size_func; 161 b->BuildAttrValue(captured_window_size_func_->func(), &window_size_func); 162 163 AttrValue key_func_other_arguments_types_attr; 164 b->BuildAttrValue(key_func_other_arguments_types, 165 &key_func_other_arguments_types_attr); 166 AttrValue reduce_func_other_arguments_types_attr; 167 b->BuildAttrValue(reduce_func_other_arguments_types, 168 &reduce_func_other_arguments_types_attr); 169 AttrValue window_size_func_other_arguments_types_attr; 170 b->BuildAttrValue(window_size_func_other_arguments_types, 171 &window_size_func_other_arguments_types_attr); 172 173 TF_RETURN_IF_ERROR(b->AddDataset( 174 this, {{0, input_graph_node}}, 175 {{1, key_func_other_arguments_node}, 176 {2, reduce_func_other_arguments_node}, 177 {3, window_size_func_other_arguments_node}}, 178 {{"key_func", key_func}, 179 {"reduce_func", reduce_func}, 180 {"window_size_func", window_size_func}, 181 {"Tkey_func_other_arguments", key_func_other_arguments_types_attr}, 182 {"Treduce_func_other_arguments", 183 reduce_func_other_arguments_types_attr}, 184 {"Twindow_size_func_other_arguments", 185 window_size_func_other_arguments_types_attr}}, 186 output)); 187 return OkStatus(); 188 } 189 190 private: 191 class Iterator : public DatasetIterator<Dataset> { 192 public: Iterator(const Params & params)193 explicit Iterator(const Params& params) 194 : DatasetIterator<Dataset>(params) {} 195 Initialize(IteratorContext * ctx)196 Status Initialize(IteratorContext* ctx) override { 197 TF_RETURN_IF_ERROR( 198 dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_)); 199 TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate( 200 ctx, &instantiated_key_func_)); 201 TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate( 202 ctx, &instantiated_reduce_func_)); 203 TF_RETURN_IF_ERROR(dataset()->captured_window_size_func_->Instantiate( 204 ctx, &instantiated_window_size_func_)); 205 return OkStatus(); 206 } 207 GetNextInternal(IteratorContext * ctx,std::vector<Tensor> * out_tensors,bool * end_of_sequence)208 Status GetNextInternal(IteratorContext* ctx, 209 std::vector<Tensor>* out_tensors, 210 bool* end_of_sequence) override { 211 mutex_lock l(mu_); 212 do { 213 if (current_group_iterator_) { 214 // We are currently processing a group, so try to get the 215 // next element. 216 bool end_of_group; 217 TF_RETURN_IF_ERROR(current_group_iterator_->GetNext( 218 MakeNestedIteratorContext(ctx), out_tensors, &end_of_group)); 219 if (!end_of_group) { 220 // Produce the subelement as output. 221 *end_of_sequence = false; 222 return OkStatus(); 223 } 224 // We have reached the end of the current group, so maybe move on 225 // to the next group. 226 current_group_iterator_.reset(); 227 groups_.erase(current_key_); 228 } 229 230 // Iterate through the input dataset until we get a full 231 // group, or reach the end. 232 while (!end_of_input_) { 233 std::vector<Tensor> next_input_element; 234 TF_RETURN_IF_ERROR( 235 input_impl_->GetNext(MakeNestedIteratorContext(ctx), 236 &next_input_element, &end_of_input_)); 237 238 if (!end_of_input_) { 239 // Run the key function on the input element to identify its 240 // group. 241 std::vector<Tensor> key_func_output; 242 TF_RETURN_IF_ERROR(instantiated_key_func_->RunWithBorrowedArgs( 243 ctx, next_input_element, &key_func_output, model_node())); 244 245 if (key_func_output.size() != 1 || 246 key_func_output[0].dtype() != DT_INT64 || 247 key_func_output[0].NumElements() != 1) { 248 // TODO(b/78665031): Support non-int64 keys. 249 return errors::InvalidArgument( 250 "`key_func` must return a scalar int64."); 251 } 252 const int64_t key = key_func_output[0].scalar<int64_t>()(); 253 254 if (window_sizes_.find(key) == window_sizes_.end()) { 255 // Run the window size function on the key to identify its 256 // window size. 257 std::vector<Tensor> window_size_func_output; 258 TF_RETURN_IF_ERROR(instantiated_window_size_func_->Run( 259 ctx, std::move(key_func_output), &window_size_func_output, 260 model_node())); 261 262 if (window_size_func_output.size() != 1 || 263 window_size_func_output[0].dtype() != DT_INT64 || 264 window_size_func_output[0].NumElements() != 1) { 265 // TODO(mrry): Support non-int64 window sizes. 266 return errors::InvalidArgument( 267 "`window_size_func` must return a scalar int64."); 268 } 269 const int64_t window_size = 270 window_size_func_output[0].scalar<int64_t>()(); 271 if (window_size <= 0) { 272 return errors::InvalidArgument( 273 "Window size must be greater than zero, but got ", 274 window_size, "."); 275 } 276 window_sizes_[key] = window_size; 277 } 278 279 const int64_t window_size = window_sizes_[key]; 280 281 std::vector<std::vector<Tensor>>& group = groups_[key]; 282 group.push_back(std::move(next_input_element)); 283 284 if (group.size() == window_size) { 285 current_key_ = key; 286 TF_RETURN_IF_ERROR(StartFlushingGroup(ctx, key)); 287 break; 288 } 289 } 290 } 291 292 if (end_of_input_) { 293 if (!groups_.empty()) { 294 // We have consumed all of the input, so flush an 295 // arbitrarily chosen group. 296 current_key_ = groups_.begin()->first; 297 TF_RETURN_IF_ERROR( 298 StartFlushingGroup(ctx, groups_.begin()->first)); 299 } 300 } 301 } while (current_group_iterator_ || !end_of_input_); 302 303 *end_of_sequence = true; 304 return OkStatus(); 305 } 306 307 protected: CreateNode(IteratorContext * ctx,model::Node::Args args) const308 std::shared_ptr<model::Node> CreateNode( 309 IteratorContext* ctx, model::Node::Args args) const override { 310 return model::MakeUnknownRatioNode(std::move(args)); 311 } 312 SaveInternal(SerializationContext * ctx,IteratorStateWriter * writer)313 Status SaveInternal(SerializationContext* ctx, 314 IteratorStateWriter* writer) override { 315 TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( 316 dataset()->captured_key_func_->CheckExternalState())); 317 TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( 318 dataset()->captured_reduce_func_->CheckExternalState())); 319 TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus( 320 dataset()->captured_window_size_func_->CheckExternalState())); 321 mutex_lock l(mu_); 322 TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_)); 323 324 if (end_of_input_) { 325 TF_RETURN_IF_ERROR( 326 writer->WriteScalar(full_name("end_of_input"), "")); 327 } 328 329 // Saving groups_ 330 if (!groups_.empty()) { 331 TF_RETURN_IF_ERROR( 332 writer->WriteScalar(full_name("groups_size"), groups_.size())); 333 int idx = 0; 334 for (auto it = groups_.begin(); it != groups_.end(); it++) { 335 int64_t key = it->first; 336 TF_RETURN_IF_ERROR(writer->WriteScalar( 337 full_name(strings::StrCat("groups_[", idx, "]->key")), key)); 338 TF_RETURN_IF_ERROR(SaveGroup( 339 writer, full_name(strings::StrCat("groups_[", idx, "]")), 340 it->second)); 341 idx++; 342 } 343 } 344 345 // Saving window_sizes_ 346 if (!window_sizes_.empty()) { 347 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("window_sizes_size"), 348 window_sizes_.size())); 349 int idx = 0; 350 for (auto it = window_sizes_.begin(); it != window_sizes_.end(); 351 it++) { 352 TF_RETURN_IF_ERROR(writer->WriteScalar( 353 full_name(strings::StrCat("window_sizes_[", idx, "]->key")), 354 it->first)); 355 TF_RETURN_IF_ERROR(writer->WriteScalar( 356 full_name(strings::StrCat("window_sizes_[", idx, "]->value")), 357 it->second)); 358 idx++; 359 } 360 } 361 362 if (current_group_iterator_) { 363 TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_group_iterator_)); 364 365 // Saving current_key_ 366 TF_RETURN_IF_ERROR( 367 writer->WriteScalar(full_name("current_key"), current_key_)); 368 } else { 369 TF_RETURN_IF_ERROR(writer->WriteScalar( 370 full_name("current_iterator_not_initialized"), "")); 371 } 372 TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("group_counter"), 373 group_counter_ - 1)); 374 return OkStatus(); 375 } 376 RestoreInternal(IteratorContext * ctx,IteratorStateReader * reader)377 Status RestoreInternal(IteratorContext* ctx, 378 IteratorStateReader* reader) override { 379 mutex_lock l(mu_); 380 TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); 381 382 if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true; 383 384 // Restoring groups_ 385 if (reader->Contains(full_name("groups_size"))) { 386 int64_t size; 387 TF_RETURN_IF_ERROR( 388 reader->ReadScalar(full_name("groups_size"), &size)); 389 for (int idx = 0; idx < size; idx++) { 390 int64_t key; 391 TF_RETURN_IF_ERROR(reader->ReadScalar( 392 full_name(strings::StrCat("groups_[", idx, "]->key")), &key)); 393 std::vector<std::vector<Tensor>> group; 394 TF_RETURN_IF_ERROR(RestoreGroup( 395 ctx, reader, full_name(strings::StrCat("groups_[", idx, "]")), 396 &group)); 397 groups_[key] = group; 398 } 399 } 400 401 // Restoring window_sizes_ 402 if (reader->Contains(full_name("window_sizes_size"))) { 403 int64_t size; 404 TF_RETURN_IF_ERROR( 405 reader->ReadScalar(full_name("window_sizes_size"), &size)); 406 for (int idx = 0; idx < size; idx++) { 407 int64_t key; 408 TF_RETURN_IF_ERROR(reader->ReadScalar( 409 full_name(strings::StrCat("window_sizes_[", idx, "]->key")), 410 &key)); 411 TF_RETURN_IF_ERROR(reader->ReadScalar( 412 full_name(strings::StrCat("window_sizes_[", idx, "]->value")), 413 &window_sizes_[key])); 414 } 415 } 416 417 // Group counter needs to be restored before current group iterator. 418 TF_RETURN_IF_ERROR( 419 reader->ReadScalar(full_name("group_counter"), &group_counter_)); 420 421 if (reader->Contains(full_name("current_iterator_not_initialized"))) { 422 current_group_iterator_.reset(); 423 } else { 424 // Restore current_key_ 425 TF_RETURN_IF_ERROR( 426 reader->ReadScalar(full_name("current_key"), ¤t_key_)); 427 428 // Initialize current_group_iterator_ 429 TF_RETURN_IF_ERROR(StartFlushingGroup(ctx, current_key_)); 430 // Restore current_group_iterator_ state 431 TF_RETURN_IF_ERROR( 432 RestoreInput(ctx, reader, current_group_iterator_)); 433 } 434 return OkStatus(); 435 } 436 437 private: SaveGroup(IteratorStateWriter * writer,const string & name,const std::vector<std::vector<Tensor>> & group)438 Status SaveGroup(IteratorStateWriter* writer, const string& name, 439 const std::vector<std::vector<Tensor>>& group) 440 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { 441 TF_RETURN_IF_ERROR( 442 writer->WriteScalar(strings::StrCat(name, "_size"), group.size())); 443 for (int i = 0; i < group.size(); i++) { 444 TF_RETURN_IF_ERROR(writer->WriteScalar( 445 strings::StrCat(name, "[", i, "]_size"), group[i].size())); 446 for (int j = 0; j < group[i].size(); j++) { 447 TF_RETURN_IF_ERROR(writer->WriteTensor( 448 strings::StrCat(name, "[", i, "][", j, "]"), group[i][j])); 449 } 450 } 451 return OkStatus(); 452 } 453 RestoreGroup(IteratorContext * ctx,IteratorStateReader * reader,const string & name,std::vector<std::vector<Tensor>> * group)454 Status RestoreGroup(IteratorContext* ctx, IteratorStateReader* reader, 455 const string& name, 456 std::vector<std::vector<Tensor>>* group) 457 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { 458 int64_t group_size; 459 TF_RETURN_IF_ERROR( 460 reader->ReadScalar(strings::StrCat(name, "_size"), &group_size)); 461 group->resize(group_size); 462 for (int i = 0; i < group_size; i++) { 463 int64_t vector_size; 464 TF_RETURN_IF_ERROR(reader->ReadScalar( 465 strings::StrCat(name, "[", i, "]_size"), &vector_size)); 466 group->at(i).resize(vector_size); 467 for (int j = 0; j < vector_size; j++) { 468 TF_RETURN_IF_ERROR(reader->ReadTensor( 469 ctx->flr(), strings::StrCat(name, "[", i, "][", j, "]"), 470 &group->at(i)[j])); 471 } 472 } 473 return OkStatus(); 474 } 475 StartFlushingGroup(IteratorContext * ctx,int64_t key)476 Status StartFlushingGroup(IteratorContext* ctx, int64_t key) 477 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { 478 DatasetBase* group_dataset; 479 TF_RETURN_IF_ERROR( 480 NewWindow(groups_[key], dataset()->input_->output_dtypes(), 481 dataset()->input_->output_shapes(), &group_dataset)); 482 483 Tensor key_arg(DT_INT64, TensorShape({})); 484 key_arg.scalar<int64_t>()() = key; 485 486 Tensor group_dataset_arg(DT_VARIANT, TensorShape({})); 487 TF_RETURN_IF_ERROR( 488 StoreDatasetInVariantTensor(group_dataset, &group_dataset_arg)); 489 490 std::vector<Tensor> args( 491 {std::move(key_arg), std::move(group_dataset_arg)}); 492 std::vector<Tensor> return_values; 493 // If not restoring, pass the model node of this iterator in order to 494 // exclude captured function run time from being added to the processing 495 // time of the node. If restoring, pass nullptr to not record processing 496 // time because iterator modeling is only used to model Iterator's 497 // GetNext() resource usage. 498 TF_RETURN_IF_ERROR(instantiated_reduce_func_->Run( 499 ctx, std::move(args), &return_values, 500 ctx->is_restoring() ? nullptr : model_node())); 501 502 if (!(return_values.size() == 1 && 503 return_values[0].dtype() == DT_VARIANT && 504 TensorShapeUtils::IsScalar(return_values[0].shape()))) { 505 return errors::InvalidArgument( 506 "`reduce_func` must return a single scalar of dtype " 507 "DT_VARIANT."); 508 } 509 510 // Retrieve the dataset that was created in `f`. 511 // `returned_dataset` is borrowed from the `return_values[0]`. 512 DatasetBase* returned_dataset; 513 TF_RETURN_IF_ERROR( 514 GetDatasetFromVariantTensor(return_values[0], &returned_dataset)); 515 516 // Create an iterator for the dataset that was returned by `f`. 517 return returned_dataset->MakeIterator( 518 MakeNestedIteratorContext(ctx), this, 519 strings::StrCat(prefix(), "[", group_counter_++, "]"), 520 ¤t_group_iterator_); 521 } 522 523 mutex mu_; 524 int64_t group_counter_ TF_GUARDED_BY(mu_) = 0; 525 std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_); 526 // TODO(mrry): Optimize for dense key space if appropriate. 527 bool end_of_input_ TF_GUARDED_BY(mu_) = false; 528 int64_t current_key_ TF_GUARDED_BY(mu_); 529 std::map<int64_t, std::vector<std::vector<Tensor>>> groups_ 530 TF_GUARDED_BY(mu_); 531 std::unique_ptr<IteratorBase> current_group_iterator_ TF_GUARDED_BY(mu_); 532 std::map<int64_t, int64_t> window_sizes_ TF_GUARDED_BY(mu_); 533 std::unique_ptr<InstantiatedCapturedFunction> instantiated_key_func_; 534 std::unique_ptr<InstantiatedCapturedFunction> instantiated_reduce_func_; 535 std::unique_ptr<InstantiatedCapturedFunction> 536 instantiated_window_size_func_; 537 }; 538 539 const DatasetBase* const input_; 540 const std::unique_ptr<CapturedFunction> captured_key_func_; 541 const std::unique_ptr<CapturedFunction> captured_reduce_func_; 542 const std::unique_ptr<CapturedFunction> captured_window_size_func_; 543 const DataTypeVector output_types_; 544 const std::vector<PartialTensorShape> output_shapes_; 545 }; 546 547 std::shared_ptr<FunctionMetadata> key_func_metadata_ = nullptr; 548 std::shared_ptr<FunctionMetadata> reduce_func_metadata_ = nullptr; 549 std::shared_ptr<FunctionMetadata> window_size_func_metadata_ = nullptr; 550 DataTypeVector output_types_; 551 std::vector<PartialTensorShape> output_shapes_; 552 }; 553 554 REGISTER_KERNEL_BUILDER(Name("GroupByWindowDataset").Device(DEVICE_CPU), 555 GroupByWindowDatasetOp); 556 REGISTER_KERNEL_BUILDER( 557 Name("ExperimentalGroupByWindowDataset").Device(DEVICE_CPU), 558 GroupByWindowDatasetOp); 559 560 REGISTER_INPUT_COLOCATION_EXEMPTION("GroupByWindowDataset"); 561 REGISTER_INPUT_COLOCATION_EXEMPTION("ExperimentalGroupByWindowDataset"); 562 563 } // namespace 564 } // namespace experimental 565 } // namespace data 566 } // namespace tensorflow 567