1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/framework/common_shape_fns.h" 16 #include "tensorflow/core/framework/full_type.pb.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/op_def_builder.h" 19 #include "tensorflow/core/framework/shape_inference.h" 20 21 namespace tensorflow { 22 23 // -------------------------------------------------------------------------- 24 25 // The ops in this section can be composed to define an input 26 // pipeline. Each op produces a DT_VARIANT tensor that represents 27 // a DAG of "dataset" objects. An "dataset" object can be converted 28 // to a stateful "iterator" by passing the "dataset" to the 29 // "MakeIterator" op. 30 // 31 // TODO(b/123753214): DT_VARIANT tensors that represent "dataset" objects are 32 // not presently serializable. To avoid issues with graph optimizations, such 33 // as constant folding, CSE, or DCE, ensure that any "source dataset" ops 34 // (i.e. ops that output a dataset and do not take one as input) are 35 // marked as "do not optimize". 36 37 // TODO(mrry): Validate that `components` have shapes compatible with 38 // `output_shapes`. 
39 REGISTER_OP("TensorDataset") 40 .Input("components: Toutput_types") 41 .Output("handle: variant") 42 .Attr("Toutput_types: list(type) >= 1") 43 .Attr("output_shapes: list(shape) >= 1") 44 .Attr("metadata: string = ''") 45 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 46 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 47 "Toutput_types")) 48 .SetShapeFn(shape_inference::ScalarShape); 49 50 // TODO(mrry): Validate that the dim-0 slices of `components` have shapes 51 // compatible with `output_shapes`. 52 REGISTER_OP("TensorSliceDataset") 53 .Input("components: Toutput_types") 54 .Output("handle: variant") 55 .Attr("Toutput_types: list(type) >= 1") 56 .Attr("output_shapes: list(shape) >= 1") 57 .Attr("is_files: bool = false") 58 .Attr("metadata: string = ''") 59 .Attr("replicate_on_split: bool = false") 60 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 61 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 62 "Toutput_types")) 63 .SetForwardTypeFn(full_type::MultiaryUnstack(TFT_DATASET, 64 full_type::UnstackTensor)) 65 .SetShapeFn(shape_inference::ScalarShape); 66 67 REGISTER_OP("SparseTensorSliceDataset") 68 .Input("indices: int64") 69 .Input("values: Tvalues") 70 .Input("dense_shape: int64") 71 .Output("handle: variant") 72 .Attr("Tvalues: type") 73 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 
74 .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET, "Tvalues")) 75 .SetShapeFn(shape_inference::ScalarShape); 76 77 REGISTER_OP("GeneratorDataset") 78 .Input("init_func_other_args: Tinit_func_args") 79 .Input("next_func_other_args: Tnext_func_args") 80 .Input("finalize_func_other_args: Tfinalize_func_args") 81 .Output("handle: variant") 82 .Attr("init_func: func") 83 .Attr("next_func: func") 84 .Attr("finalize_func: func") 85 .Attr("Tinit_func_args: list(type) >= 0") 86 .Attr("Tnext_func_args: list(type) >= 0") 87 .Attr("Tfinalize_func_args: list(type) >= 0") 88 .Attr("output_types: list(type) >= 1") 89 .Attr("output_shapes: list(shape) >= 1") 90 .Attr("metadata: string = ''") 91 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 92 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 93 "output_types")) 94 .SetShapeFn(shape_inference::ScalarShape); 95 96 REGISTER_OP("ZipDataset") 97 .Input("input_datasets: N * variant") 98 .Output("handle: variant") 99 .Attr("output_types: list(type) >= 1") 100 .Attr("output_shapes: list(shape) >= 1") 101 .Attr("N: int >= 1") 102 .Attr("metadata: string = ''") 103 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 104 "output_types")) 105 .SetShapeFn(shape_inference::ScalarShape); 106 107 REGISTER_OP("ConcatenateDataset") 108 .Input("input_dataset: variant") 109 .Input("another_dataset: variant") 110 .Output("handle: variant") 111 .Attr("output_types: list(type) >= 1") 112 .Attr("output_shapes: list(shape) >= 1") 113 .Attr("metadata: string = ''") 114 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 115 "output_types")) 116 .SetShapeFn(shape_inference::ScalarShape); 117 118 REGISTER_OP("RepeatDataset") 119 .Input("input_dataset: variant") 120 .Input("count: int64") 121 .Output("handle: variant") 122 .Attr("output_types: list(type) >= 1") 123 .Attr("output_shapes: list(shape) >= 1") 124 .Attr("metadata: string = ''") 125 
.SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 126 "output_types")) __anon4377504f0102(shape_inference::InferenceContext* c) 127 .SetShapeFn([](shape_inference::InferenceContext* c) { 128 shape_inference::ShapeHandle count_shape; 129 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape)); 130 return shape_inference::ScalarShape(c); 131 }); 132 133 REGISTER_OP("TakeDataset") 134 .Input("input_dataset: variant") 135 .Input("count: int64") 136 .Output("handle: variant") 137 .Attr("output_types: list(type) >= 1") 138 .Attr("output_shapes: list(shape) >= 1") 139 .Attr("metadata: string = ''") 140 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 141 "output_types")) __anon4377504f0202(shape_inference::InferenceContext* c) 142 .SetShapeFn([](shape_inference::InferenceContext* c) { 143 shape_inference::ShapeHandle count_shape; 144 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape)); 145 return shape_inference::ScalarShape(c); 146 }); 147 148 REGISTER_OP("SkipDataset") 149 .Input("input_dataset: variant") 150 .Input("count: int64") 151 .Output("handle: variant") 152 .Attr("output_types: list(type) >= 1") 153 .Attr("output_shapes: list(shape) >= 1") 154 .Attr("metadata: string = ''") 155 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 156 "output_types")) __anon4377504f0302(shape_inference::InferenceContext* c) 157 .SetShapeFn([](shape_inference::InferenceContext* c) { 158 shape_inference::ShapeHandle count_shape; 159 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape)); 160 return shape_inference::ScalarShape(c); 161 }); 162 163 REGISTER_OP("MapDataset") 164 .Input("input_dataset: variant") 165 .Input("other_arguments: Targuments") 166 .Output("handle: variant") 167 .Attr("f: func") 168 .Attr("Targuments: list(type) >= 0") 169 .Attr("output_types: list(type) >= 1") 170 .Attr("output_shapes: list(shape) >= 1") 171 .Attr("use_inter_op_parallelism: bool = true") 172 
.Attr("preserve_cardinality: bool = false") 173 .Attr("metadata: string = ''") 174 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 175 "output_types")) 176 .SetShapeFn(shape_inference::ScalarShape); 177 178 REGISTER_OP("ParallelMapDataset") 179 .Input("input_dataset: variant") 180 .Input("other_arguments: Targuments") 181 .Input("num_parallel_calls: int32") 182 .Output("handle: variant") 183 .Attr("f: func") 184 .Attr("Targuments: list(type) >= 0") 185 .Attr("output_types: list(type) >= 1") 186 .Attr("output_shapes: list(shape) >= 1") 187 .Attr("use_inter_op_parallelism: bool = true") 188 .Attr("sloppy: bool = false") 189 .Attr("preserve_cardinality: bool = false") 190 .Attr("metadata: string = ''") 191 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 192 "output_types")) 193 .SetShapeFn(shape_inference::ScalarShape); 194 195 REGISTER_OP("ParallelMapDatasetV2") 196 .Input("input_dataset: variant") 197 .Input("other_arguments: Targuments") 198 .Input("num_parallel_calls: int64") 199 .Output("handle: variant") 200 .Attr("f: func") 201 .Attr("Targuments: list(type) >= 0") 202 .Attr("output_types: list(type) >= 1") 203 .Attr("output_shapes: list(shape) >= 1") 204 .Attr("use_inter_op_parallelism: bool = true") 205 // "true", "false", or "default". 
206 .Attr("deterministic: string = 'default'") 207 .Attr("preserve_cardinality: bool = false") 208 .Attr("metadata: string = ''") 209 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 210 "output_types")) 211 .SetShapeFn(shape_inference::ScalarShape); 212 213 REGISTER_OP("PrefetchDataset") 214 .Input("input_dataset: variant") 215 .Input("buffer_size: int64") 216 .Output("handle: variant") 217 .Attr("output_types: list(type) >= 1") 218 .Attr("output_shapes: list(shape) >= 1") 219 .Attr("slack_period: int = 0") 220 .Attr("legacy_autotune: bool = true") 221 .Attr("buffer_size_min: int = 0") 222 .Attr("metadata: string = ''") 223 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 224 "output_types")) __anon4377504f0402(shape_inference::InferenceContext* c) 225 .SetShapeFn([](shape_inference::InferenceContext* c) { 226 shape_inference::ShapeHandle unused; 227 // buffer_size should be a scalar. 228 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 229 return shape_inference::ScalarShape(c); 230 }); 231 232 REGISTER_OP("FlatMapDataset") 233 .Input("input_dataset: variant") 234 .Input("other_arguments: Targuments") 235 .Output("handle: variant") 236 .Attr("f: func") 237 .Attr("Targuments: list(type) >= 0") 238 .Attr("output_types: list(type) >= 1") 239 .Attr("output_shapes: list(shape) >= 1") 240 .Attr("metadata: string = ''") 241 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 242 "output_types")) 243 .SetShapeFn(shape_inference::ScalarShape); 244 245 REGISTER_OP("InterleaveDataset") 246 .Input("input_dataset: variant") 247 .Input("other_arguments: Targuments") 248 .Input("cycle_length: int64") 249 .Input("block_length: int64") 250 .Output("handle: variant") 251 .Attr("f: func") 252 .Attr("Targuments: list(type) >= 0") 253 .Attr("output_types: list(type) >= 1") 254 .Attr("output_shapes: list(shape) >= 1") 255 .Attr("metadata: string = ''") 256 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 
257 "output_types")) 258 .SetShapeFn(shape_inference::ScalarShape); 259 260 REGISTER_OP("ParallelInterleaveDatasetV2") 261 .Input("input_dataset: variant") 262 .Input("other_arguments: Targuments") 263 .Input("cycle_length: int64") 264 .Input("block_length: int64") 265 .Input("num_parallel_calls: int64") 266 .Output("handle: variant") 267 .Attr("f: func") 268 .Attr("Targuments: list(type) >= 0") 269 .Attr("output_types: list(type) >= 1") 270 .Attr("output_shapes: list(shape) >= 1") 271 .Attr("sloppy: bool = false") 272 .Attr("metadata: string = ''") 273 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 274 "output_types")) 275 .SetShapeFn(shape_inference::ScalarShape); 276 277 REGISTER_OP("ParallelInterleaveDatasetV3") 278 .Input("input_dataset: variant") 279 .Input("other_arguments: Targuments") 280 .Input("cycle_length: int64") 281 .Input("block_length: int64") 282 .Input("num_parallel_calls: int64") 283 .Output("handle: variant") 284 .Attr("f: func") 285 // "true", "false", or "default". 286 .Attr("deterministic: string = 'default'") 287 .Attr("Targuments: list(type) >= 0") 288 .Attr("output_types: list(type) >= 1") 289 .Attr("output_shapes: list(shape) >= 1") 290 .Attr("metadata: string = ''") 291 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 292 "output_types")) 293 .SetShapeFn(shape_inference::ScalarShape); 294 295 // Like V3, but adds buffer_output_elements and prefetch_input_elements. 296 REGISTER_OP("ParallelInterleaveDatasetV4") 297 .Input("input_dataset: variant") 298 .Input("other_arguments: Targuments") 299 .Input("cycle_length: int64") 300 .Input("block_length: int64") 301 .Input("buffer_output_elements: int64") 302 .Input("prefetch_input_elements: int64") 303 .Input("num_parallel_calls: int64") 304 .Output("handle: variant") 305 .Attr("f: func") 306 // "true", "false", or "default". 
307 .Attr("deterministic: string = 'default'") 308 .Attr("Targuments: list(type) >= 0") 309 .Attr("output_types: list(type) >= 1") 310 .Attr("output_shapes: list(shape) >= 1") 311 .Attr("metadata: string = ''") 312 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 313 "output_types")) 314 .SetShapeFn(shape_inference::ScalarShape); 315 316 REGISTER_OP("FilterDataset") 317 .Input("input_dataset: variant") 318 .Input("other_arguments: Targuments") 319 .Output("handle: variant") 320 .Attr("predicate: func") 321 .Attr("Targuments: list(type) >= 0") 322 .Attr("output_types: list(type) >= 1") 323 .Attr("output_shapes: list(shape) >= 1") 324 .Attr("metadata: string = ''") 325 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 326 "output_types")) 327 .SetShapeFn(shape_inference::ScalarShape); 328 329 REGISTER_OP("ParallelFilterDataset") 330 .Input("input_dataset: variant") 331 .Input("other_arguments: Targuments") 332 .Input("num_parallel_calls: int64") 333 .Output("handle: variant") 334 .Attr("predicate: func") 335 // "true", "false", or "default". 336 .Attr("deterministic: string = 'default'") 337 .Attr("Targuments: list(type) >= 0") 338 .Attr("output_types: list(type) >= 1") 339 .Attr("output_shapes: list(shape) >= 1") 340 .Attr("metadata: string = ''") 341 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 342 "output_types")) 343 .SetShapeFn(shape_inference::ScalarShape); 344 345 // This op is no longer supported. 
346 REGISTER_OP("FilterByLastComponentDataset") 347 .Input("input_dataset: variant") 348 .Output("output: variant") 349 .Attr("output_types: list(type) >= 1") 350 .Attr("output_shapes: list(shape) >= 1") 351 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 352 "output_types")) 353 .SetShapeFn(shape_inference::ScalarShape); 354 355 REGISTER_OP("WindowDataset") 356 .Input("input_dataset: variant") 357 .Input("size: int64") 358 .Input("shift: int64") 359 .Input("stride: int64") 360 .Input("drop_remainder: bool") 361 .Output("handle: variant") 362 .Attr("output_types: list(type) >= 1") 363 .Attr("output_shapes: list(shape) >= 1") 364 .Attr("metadata: string = ''") 365 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 366 "output_types")) __anon4377504f0502(shape_inference::InferenceContext* c) 367 .SetShapeFn([](shape_inference::InferenceContext* c) { 368 shape_inference::ShapeHandle unused; 369 // size, shift, stride, and drop_remainder should be scalars. 
370 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 371 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 372 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 373 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 374 return shape_inference::ScalarShape(c); 375 }); 376 377 REGISTER_OP("WindowOp") 378 .Input("inputs: Tinputs") 379 .Output("handle: variant") 380 .Attr("output_types: list(type) >= 1") 381 .Attr("output_shapes: list(shape) >= 1") 382 .Attr("Tinputs: list(type) >= 1") 383 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 384 "output_types")) 385 .SetShapeFn(shape_inference::ScalarShape); 386 387 REGISTER_OP("BatchDataset") 388 .Input("input_dataset: variant") 389 .Input("batch_size: int64") 390 .Output("handle: variant") 391 .Attr("output_types: list(type) >= 1") 392 .Attr("output_shapes: list(shape) >= 1") 393 .Attr("metadata: string = ''") 394 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 395 "output_types")) __anon4377504f0602(shape_inference::InferenceContext* c) 396 .SetShapeFn([](shape_inference::InferenceContext* c) { 397 shape_inference::ShapeHandle unused; 398 // batch_size should be a scalar. 
399 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 400 return shape_inference::ScalarShape(c); 401 }); 402 403 REGISTER_OP("BatchDatasetV2") 404 .Input("input_dataset: variant") 405 .Input("batch_size: int64") 406 .Input("drop_remainder: bool") 407 .Output("handle: variant") 408 .Attr("parallel_copy: bool = false") 409 .Attr("output_types: list(type) >= 1") 410 .Attr("output_shapes: list(shape) >= 1") 411 .Attr("metadata: string = ''") 412 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 413 "output_types")) 414 .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0, 415 full_type::BatchTensor)) __anon4377504f0702(shape_inference::InferenceContext* c) 416 .SetShapeFn([](shape_inference::InferenceContext* c) { 417 shape_inference::ShapeHandle unused; 418 // batch_size should be a scalar. 419 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 420 // drop_remainder should be a scalar. 421 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 422 return shape_inference::ScalarShape(c); 423 }); 424 425 REGISTER_OP("ParallelBatchDataset") 426 .Input("input_dataset: variant") 427 .Input("batch_size: int64") 428 .Input("num_parallel_calls: int64") 429 .Input("drop_remainder: bool") 430 .Output("handle: variant") 431 .Attr("parallel_copy: bool = false") 432 .Attr("output_types: list(type) >= 1") 433 .Attr("output_shapes: list(shape) >= 1") 434 // "true", "false", or "default". 435 .Attr("deterministic: string = 'default'") 436 .Attr("metadata: string = ''") 437 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 438 "output_types")) __anon4377504f0802(shape_inference::InferenceContext* c) 439 .SetShapeFn([](shape_inference::InferenceContext* c) { 440 shape_inference::ShapeHandle unused; 441 // batch_size should be a scalar. 442 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 443 // num_parallel_calls should be a scalar. 
444 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 445 // drop_remainder should be a scalar. 446 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 447 return shape_inference::ScalarShape(c); 448 }); 449 450 REGISTER_OP("ShardDataset") 451 .Input("input_dataset: variant") 452 .Input("num_shards: int64") 453 .Input("index: int64") 454 .Output("handle: variant") 455 .Attr("require_non_empty: bool = false") 456 .Attr("output_types: list(type) >= 1") 457 .Attr("output_shapes: list(shape) >= 1") 458 .Attr("metadata: string = ''") 459 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 460 "output_types")) __anon4377504f0902(shape_inference::InferenceContext* c) 461 .SetShapeFn([](shape_inference::InferenceContext* c) { 462 shape_inference::ShapeHandle unused; 463 // num_shards should be a scalar. 464 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 465 // index should be a scalar. 466 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 467 return shape_inference::ScalarShape(c); 468 }); 469 470 // TODO(mrry): Validate that `padded_shapes` are all vectors, the lengths of 471 // `output_types` and `output_shapes` are `N` the `output_shapes` are (as far as 472 // possible to tell statically) compatible with `padded_shapes`, and that 473 // `padding_values` are all scalars. 474 REGISTER_OP("PaddedBatchDataset") 475 .Input("input_dataset: variant") 476 .Input("batch_size: int64") 477 .Input("padded_shapes: N * int64") 478 .Input("padding_values: Toutput_types") 479 .Output("handle: variant") 480 .Attr("Toutput_types: list(type) >= 1") 481 .Attr("output_shapes: list(shape) >= 1") 482 .Attr("N: int >= 1") 483 .Attr("metadata: string = ''") 484 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 485 "Toutput_types")) __anon4377504f0a02(shape_inference::InferenceContext* c) 486 .SetShapeFn([](shape_inference::InferenceContext* c) { 487 shape_inference::ShapeHandle unused; 488 // batch_size should be a scalar. 
489 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 490 return shape_inference::ScalarShape(c); 491 }); 492 493 REGISTER_OP("PaddedBatchDatasetV2") 494 .Input("input_dataset: variant") 495 .Input("batch_size: int64") 496 .Input("padded_shapes: N * int64") 497 .Input("padding_values: Toutput_types") 498 .Input("drop_remainder: bool") 499 .Output("handle: variant") 500 .Attr("parallel_copy: bool = false") 501 .Attr("Toutput_types: list(type) >= 1") 502 .Attr("output_shapes: list(shape) >= 1") 503 .Attr("N: int >= 1") 504 .Attr("metadata: string = ''") 505 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 506 "Toutput_types")) __anon4377504f0b02(shape_inference::InferenceContext* c) 507 .SetShapeFn([](shape_inference::InferenceContext* c) { 508 shape_inference::ShapeHandle unused; 509 // batch_size should be a scalar. 510 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 511 // drop_remainder should be a scalar. 512 TF_RETURN_IF_ERROR( 513 c->WithRank(c->input(c->num_inputs() - 1), 0, &unused)); 514 return shape_inference::ScalarShape(c); 515 }); 516 517 REGISTER_OP("RangeDataset") 518 .Input("start: int64") 519 .Input("stop: int64") 520 .Input("step: int64") 521 .Output("handle: variant") 522 .Attr("output_types: list(type) >= 1") 523 .Attr("output_shapes: list(shape) >= 1") 524 .Attr("metadata: string = ''") 525 .Attr("replicate_on_split: bool = false") 526 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 527 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 528 "output_types")) __anon4377504f0c02(shape_inference::InferenceContext* c) 529 .SetShapeFn([](shape_inference::InferenceContext* c) { 530 shape_inference::ShapeHandle unused; 531 // start, stop, and step should be scalars. 
532 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); 533 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 534 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 535 return shape_inference::ScalarShape(c); 536 }); 537 538 REGISTER_OP("RewriteDataset") 539 .Input("input_dataset: variant") 540 .Input("rewrite_name: string") 541 .Output("handle: variant") 542 .Attr("output_types: list(type) >= 1") 543 .Attr("output_shapes: list(shape) >= 1") 544 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 545 "output_types")) 546 .SetShapeFn(shape_inference::ScalarShape); 547 548 REGISTER_OP("AnonymousSeedGenerator") 549 .Input("seed: int64") 550 .Input("seed2: int64") 551 .Input("reshuffle: bool") 552 .Output("handle: resource") 553 .Output("deleter: variant") __anon4377504f0d02(shape_inference::InferenceContext* c) 554 .SetShapeFn([](shape_inference::InferenceContext* c) { 555 c->set_output(0, c->Scalar()); 556 c->set_output(1, c->Scalar()); 557 return OkStatus(); 558 }); 559 560 REGISTER_OP("DatasetCardinality") 561 .Input("input_dataset: variant") 562 .Output("cardinality: int64") 563 .SetShapeFn(shape_inference::ScalarShape); 564 565 REGISTER_OP("DeleteSeedGenerator") 566 .Input("handle: resource") 567 .Input("deleter: variant") 568 .SetShapeFn(shape_inference::NoOutputs); 569 570 // Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator. 571 REGISTER_OP("AnonymousRandomSeedGenerator") 572 .Input("seed: int64") 573 .Input("seed2: int64") 574 .Output("handle: resource") 575 .Output("deleter: variant") __anon4377504f0e02(shape_inference::InferenceContext* c) 576 .SetShapeFn([](shape_inference::InferenceContext* c) { 577 c->set_output(0, c->Scalar()); 578 c->set_output(1, c->Scalar()); 579 return OkStatus(); 580 }); 581 582 // Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator. 
583 REGISTER_OP("DeleteRandomSeedGenerator") 584 .Input("handle: resource") 585 .Input("deleter: variant") 586 .SetShapeFn(shape_inference::NoOutputs); 587 588 REGISTER_OP("DummySeedGenerator") 589 .Output("handle: resource") __anon4377504f0f02(shape_inference::InferenceContext* c) 590 .SetShapeFn([](shape_inference::InferenceContext* c) { 591 c->set_output(0, c->Scalar()); 592 return OkStatus(); 593 }); 594 595 REGISTER_OP("ShuffleDataset") 596 .Input("input_dataset: variant") 597 .Input("buffer_size: int64") 598 .Input("seed: int64") 599 .Input("seed2: int64") 600 .Output("handle: variant") 601 .Attr("reshuffle_each_iteration: bool = true") 602 .Attr("output_types: list(type) >= 1") 603 .Attr("output_shapes: list(shape) >= 1") 604 .Attr("metadata: string = ''") 605 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 606 "output_types")) __anon4377504f1002(shape_inference::InferenceContext* c) 607 .SetShapeFn([](shape_inference::InferenceContext* c) { 608 shape_inference::ShapeHandle unused; 609 // buffer_size, seed, and seed2 should be scalars. 610 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 611 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 612 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 613 return shape_inference::ScalarShape(c); 614 }); 615 616 REGISTER_OP("ShuffleDatasetV2") 617 .Input("input_dataset: variant") 618 .Input("buffer_size: int64") 619 .Input("seed_generator: resource") 620 .Output("handle: variant") 621 .Attr("output_types: list(type) >= 1") 622 .Attr("output_shapes: list(shape) >= 1") 623 .Attr("metadata: string = ''") 624 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 625 "output_types")) __anon4377504f1102(shape_inference::InferenceContext* c) 626 .SetShapeFn([](shape_inference::InferenceContext* c) { 627 shape_inference::ShapeHandle unused; 628 // buffer_size and seed_generator should be scalars. 
629 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 630 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 631 return shape_inference::ScalarShape(c); 632 }); 633 634 REGISTER_OP("ShuffleDatasetV3") 635 .Input("input_dataset: variant") 636 .Input("buffer_size: int64") 637 .Input("seed: int64") 638 .Input("seed2: int64") 639 .Input("seed_generator: resource") 640 .Output("handle: variant") 641 .Attr("reshuffle_each_iteration: bool = true") 642 .Attr("output_types: list(type) >= 1") 643 .Attr("output_shapes: list(shape) >= 1") 644 .Attr("metadata: string = ''") 645 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 646 "output_types")) __anon4377504f1202(shape_inference::InferenceContext* c) 647 .SetShapeFn([](shape_inference::InferenceContext* c) { 648 shape_inference::ShapeHandle unused; 649 // buffer_size, seed, seed2, and seed_generator should be scalars. 650 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 651 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 652 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 653 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 654 return shape_inference::ScalarShape(c); 655 }); 656 657 REGISTER_OP("ShuffleAndRepeatDataset") 658 .Input("input_dataset: variant") 659 .Input("buffer_size: int64") 660 .Input("seed: int64") 661 .Input("seed2: int64") 662 .Input("count: int64") 663 .Output("handle: variant") 664 .Attr("output_types: list(type) >= 1") 665 .Attr("output_shapes: list(shape) >= 1") 666 .Attr("reshuffle_each_iteration: bool = true") 667 .Attr("metadata: string = ''") 668 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 669 "output_types")) __anon4377504f1302(shape_inference::InferenceContext* c) 670 .SetShapeFn([](shape_inference::InferenceContext* c) { 671 shape_inference::ShapeHandle unused; 672 // buffer_size, seed, seed2, and count should be scalars. 
673 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 674 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 675 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 676 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 677 return shape_inference::ScalarShape(c); 678 }); 679 680 REGISTER_OP("ShuffleAndRepeatDatasetV2") 681 .Input("input_dataset: variant") 682 .Input("buffer_size: int64") 683 .Input("seed: int64") 684 .Input("seed2: int64") 685 .Input("count: int64") 686 .Input("seed_generator: resource") 687 .Output("handle: variant") 688 .Attr("reshuffle_each_iteration: bool = true") 689 .Attr("output_types: list(type) >= 1") 690 .Attr("output_shapes: list(shape) >= 1") 691 .Attr("metadata: string = ''") 692 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 693 "output_types")) __anon4377504f1402(shape_inference::InferenceContext* c) 694 .SetShapeFn([](shape_inference::InferenceContext* c) { 695 shape_inference::ShapeHandle unused; 696 // buffer_size, seed, seed2, count, and seed_generator should be scalars. 
697 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 698 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 699 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 700 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 701 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); 702 return shape_inference::ScalarShape(c); 703 }); 704 705 REGISTER_OP("AnonymousMemoryCache") 706 .Output("handle: resource") 707 .Output("deleter: variant") __anon4377504f1502(shape_inference::InferenceContext* c) 708 .SetShapeFn([](shape_inference::InferenceContext* c) { 709 c->set_output(0, c->Scalar()); 710 c->set_output(1, c->Scalar()); 711 return OkStatus(); 712 }); 713 714 REGISTER_OP("DeleteMemoryCache") 715 .Input("handle: resource") 716 .Input("deleter: variant") 717 .SetShapeFn(shape_inference::NoOutputs); 718 719 REGISTER_OP("DummyMemoryCache") 720 .Output("handle: resource") __anon4377504f1602(shape_inference::InferenceContext* c) 721 .SetShapeFn([](shape_inference::InferenceContext* c) { 722 c->set_output(0, c->Scalar()); 723 return OkStatus(); 724 }); 725 726 REGISTER_OP("CacheDataset") 727 .Input("input_dataset: variant") 728 .Input("filename: string") 729 .Output("handle: variant") 730 .Attr("output_types: list(type) >= 1") 731 .Attr("output_shapes: list(shape) >= 1") 732 .Attr("metadata: string = ''") 733 // TODO(mdan): Should these use type inference instead? 734 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 735 "output_types")) __anon4377504f1702(shape_inference::InferenceContext* c) 736 .SetShapeFn([](shape_inference::InferenceContext* c) { 737 shape_inference::ShapeHandle unused; 738 // filename should be a scalar. 
739 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 740 return shape_inference::ScalarShape(c); 741 }); 742 743 REGISTER_OP("CacheDatasetV2") 744 .Input("input_dataset: variant") 745 .Input("filename: string") 746 .Input("cache: resource") 747 .Output("handle: variant") 748 .Attr("output_types: list(type) >= 1") 749 .Attr("output_shapes: list(shape) >= 1") 750 .Attr("metadata: string = ''") 751 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 752 "output_types")) __anon4377504f1802(shape_inference::InferenceContext* c) 753 .SetShapeFn([](shape_inference::InferenceContext* c) { 754 shape_inference::ShapeHandle unused; 755 // filename should be a scalar. 756 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 757 // cache should be a scalar. 758 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 759 return shape_inference::ScalarShape(c); 760 }); 761 762 REGISTER_OP("TextLineDataset") 763 .Input("filenames: string") 764 .Input("compression_type: string") 765 .Input("buffer_size: int64") 766 .Attr("metadata: string = ''") 767 .Output("handle: variant") 768 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 769 .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET, 770 TFT_STRING)) __anon4377504f1902(shape_inference::InferenceContext* c) 771 .SetShapeFn([](shape_inference::InferenceContext* c) { 772 shape_inference::ShapeHandle unused; 773 // `filenames` must be a scalar or a vector. 774 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 775 // `compression_type` could only be a scalar. 776 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 777 // `buffer_size` could only be a scalar. 
778 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 779 return shape_inference::ScalarShape(c); 780 }); 781 782 REGISTER_OP("FixedLengthRecordDataset") 783 .Input("filenames: string") 784 .Input("header_bytes: int64") 785 .Input("record_bytes: int64") 786 .Input("footer_bytes: int64") 787 .Input("buffer_size: int64") 788 .Attr("metadata: string = ''") 789 .Output("handle: variant") 790 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 791 .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET, 792 TFT_STRING)) __anon4377504f1a02(shape_inference::InferenceContext* c) 793 .SetShapeFn([](shape_inference::InferenceContext* c) { 794 shape_inference::ShapeHandle unused; 795 // `filenames` must be a scalar or a vector. 796 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 797 // header_bytes, record_bytes, footer_bytes, buffer_size should be 798 // scalars. 799 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 800 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 801 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 802 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 803 return shape_inference::ScalarShape(c); 804 }); 805 806 REGISTER_OP("FixedLengthRecordDatasetV2") 807 .Input("filenames: string") 808 .Input("header_bytes: int64") 809 .Input("record_bytes: int64") 810 .Input("footer_bytes: int64") 811 .Input("buffer_size: int64") 812 .Input("compression_type: string") 813 .Attr("metadata: string = ''") 814 .Output("handle: variant") 815 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 816 .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET, 817 TFT_STRING)) __anon4377504f1b02(shape_inference::InferenceContext* c) 818 .SetShapeFn([](shape_inference::InferenceContext* c) { 819 shape_inference::ShapeHandle unused; 820 // `filenames` must be a scalar or a vector. 
821 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 822 // header_bytes, record_bytes, footer_bytes, buffer_size should be 823 // scalars. 824 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 825 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 826 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); 827 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); 828 return shape_inference::ScalarShape(c); 829 }); 830 831 REGISTER_OP("TFRecordDataset") 832 .Input("filenames: string") 833 .Input("compression_type: string") 834 .Input("buffer_size: int64") 835 .Attr("metadata: string = ''") 836 .Output("handle: variant") 837 .SetDoNotOptimize() // TODO(b/123753214): See comment in dataset_ops.cc. 838 .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET, 839 TFT_STRING)) __anon4377504f1c02(shape_inference::InferenceContext* c) 840 .SetShapeFn([](shape_inference::InferenceContext* c) { 841 shape_inference::ShapeHandle unused; 842 // `filenames` must be a scalar or a vector. 843 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); 844 // `compression_type` could only be a scalar. 845 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 846 // `buffer_size` could only be a scalar. 
847 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 848 return shape_inference::ScalarShape(c); 849 }); 850 851 REGISTER_OP("Iterator") 852 .Output("handle: resource") 853 .Attr("shared_name: string") 854 .Attr("container: string") 855 .Attr("output_types: list(type) >= 1") 856 .Attr("output_shapes: list(shape) >= 1") 857 .SetShapeFn(shape_inference::ScalarShape); 858 859 REGISTER_OP("IteratorV2") 860 .Output("handle: resource") 861 .Attr("shared_name: string") 862 .Attr("container: string") 863 .Attr("output_types: list(type) >= 1") 864 .Attr("output_shapes: list(shape) >= 1") 865 .SetShapeFn(shape_inference::ScalarShape); 866 867 REGISTER_OP("AnonymousIterator") 868 .Output("handle: resource") 869 .Attr("output_types: list(type) >= 1") 870 .Attr("output_shapes: list(shape) >= 1") 871 .SetShapeFn(shape_inference::ScalarShape); 872 873 REGISTER_OP("AnonymousIteratorV2") 874 .Output("handle: resource") 875 .Output("deleter: variant") 876 .Attr("output_types: list(type) >= 1") 877 .Attr("output_shapes: list(shape) >= 1") __anon4377504f1d02(shape_inference::InferenceContext* c) 878 .SetShapeFn([](shape_inference::InferenceContext* c) { 879 c->set_output(0, c->Scalar()); 880 c->set_output(1, c->Scalar()); 881 return OkStatus(); 882 }); 883 884 REGISTER_OP("AnonymousIteratorV3") 885 .Output("handle: resource") 886 .Attr("output_types: list(type) >= 1") 887 .Attr("output_shapes: list(shape) >= 1") __anon4377504f1e02(shape_inference::InferenceContext* c) 888 .SetShapeFn([](shape_inference::InferenceContext* c) { 889 c->set_output(0, c->Scalar()); 890 return OkStatus(); 891 }); 892 893 REGISTER_OP("DeleteIterator") 894 .Input("handle: resource") 895 .Input("deleter: variant") 896 .SetShapeFn(shape_inference::NoOutputs); 897 898 REGISTER_OP("DeleteMultiDeviceIterator") 899 .Input("multi_device_iterator: resource") 900 .Input("iterators: N * resource") 901 .Input("deleter: variant") 902 .Attr("N: int >= 0") 903 .SetShapeFn(shape_inference::NoOutputs); 904 905 
// MakeIterator: consumes a dataset variant and an iterator resource and has
// no outputs. Input 1 carries a reverse type function built from
// MapCovariant(TFT_DATASET, TFT_ITERATOR, 0).
REGISTER_OP("MakeIterator")
    .Input("dataset: variant")
    .Input("iterator: resource")
    .SetShapeFn(shape_inference::NoOutputs)
    .SetTypeConstructor(full_type::NoOutputs())
    .SetReverseTypeFn(1, full_type::MapCovariant(TFT_DATASET, TFT_ITERATOR, 0));

// OneShotIterator: stateful op with a `dataset_factory` function attr; it
// emits a single scalar resource handle.
REGISTER_OP("OneShotIterator")
    .Output("handle: resource")
    .Attr("dataset_factory: func")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetShapeFn(shape_inference::ScalarShape)
    .SetIsStateful();

// IteratorGetNext / IteratorGetNextSync: read from an iterator resource and
// emit one output per entry of `output_types`, shaped by
// DatasetIteratorShape. The two registrations have identical signatures.
REGISTER_OP("IteratorGetNext")
    .Input("iterator: resource")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);

REGISTER_OP("IteratorGetNextSync")
    .Input("iterator: resource")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);

// TODO(b/124308596): This op is conservatively registered as stateful.
// Rather than doing that unconditionally, detect whether `dataset` has a
// side-effect and choose between a stateless and a stateful version of the
// op based on that.
REGISTER_OP("DatasetToSingleElement")
    .Input("dataset: variant")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetShapeFn(shape_inference::DatasetIteratorShape)
    .SetIsStateful();

// TODO(b/124308596): This op is conservatively registered as stateful.
// Rather than doing that unconditionally, detect whether `dataset` has a
// side-effect and choose between a stateless and a stateful version of the
// op based on that.
953 REGISTER_OP("ReduceDataset") 954 .Input("input_dataset: variant") 955 .Input("initial_state: Tstate") 956 .Input("other_arguments: Targuments") 957 .Output("components: output_types") 958 .Attr("f: func") 959 .Attr("Tstate: list(type) >= 1") 960 .Attr("Targuments: list(type) >= 0") 961 .Attr("output_types: list(type) >= 1") 962 .Attr("output_shapes: list(shape) >= 1") 963 .Attr("use_inter_op_parallelism: bool = true") 964 .Attr("metadata: string = ''") 965 .SetIsStateful() 966 .SetShapeFn(shape_inference::DatasetIteratorShape); 967 968 REGISTER_OP("IteratorToStringHandle") 969 .Input("resource_handle: resource") 970 .Output("string_handle: string") 971 .SetShapeFn(shape_inference::ScalarShape); 972 973 REGISTER_OP("IteratorFromStringHandle") 974 .Input("string_handle: string") 975 .Output("resource_handle: resource") 976 .Attr("output_types: list(type) >= 0 = []") 977 .Attr("output_shapes: list(shape) >= 0 = []") 978 .SetShapeFn(shape_inference::ScalarShape); 979 980 REGISTER_OP("IteratorFromStringHandleV2") 981 .Input("string_handle: string") 982 .Output("resource_handle: resource") 983 .Attr("output_types: list(type) >= 0 = []") 984 .Attr("output_shapes: list(shape) >= 0 = []") 985 .SetShapeFn(shape_inference::ScalarShape); 986 987 REGISTER_OP("SerializeIterator") 988 .Input("resource_handle: resource") 989 .Attr("external_state_policy: int = 0") 990 .Output("serialized: variant") __anon4377504f1f02(shape_inference::InferenceContext* c) 991 .SetShapeFn([](shape_inference::InferenceContext* c) { 992 c->set_output(0, c->Vector(c->UnknownDim())); 993 return OkStatus(); 994 }); 995 996 REGISTER_OP("DeserializeIterator") 997 .Input("resource_handle: resource") 998 .Input("serialized: variant") 999 .SetShapeFn(shape_inference::NoOutputs); 1000 1001 REGISTER_OP("DatasetToGraph") 1002 .Input("input_dataset: variant") 1003 .Attr("stateful_whitelist: list(string) >= 0 = []") 1004 .Attr("allow_stateful: bool = false") 1005 .Attr("strip_device_assignment: bool = false") 
1006 .Output("graph: string") 1007 .SetShapeFn(shape_inference::ScalarShape); 1008 1009 REGISTER_OP("DatasetToGraphV2") 1010 .Input("input_dataset: variant") 1011 .Attr("external_state_policy: int = 0") 1012 .Attr("strip_device_assignment: bool = false") 1013 .Output("graph: string") 1014 .SetForwardTypeFn(full_type::Encode(TFT_STRING, 0)) 1015 .SetShapeFn(shape_inference::ScalarShape); 1016 1017 REGISTER_OP("OptimizeDataset") 1018 .Input("input_dataset: variant") 1019 .Input("optimizations: string") 1020 .Output("handle: variant") 1021 .Attr("output_types: list(type) >= 1") 1022 .Attr("output_shapes: list(shape) >= 1") 1023 .Attr("optimization_configs: list(string) = []") 1024 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 1025 "output_types")) 1026 .SetShapeFn(shape_inference::ScalarShape); 1027 1028 REGISTER_OP("OptimizeDatasetV2") 1029 .Input("input_dataset: variant") 1030 .Input("optimizations_enabled: string") 1031 .Input("optimizations_disabled: string") 1032 .Input("optimizations_default: string") 1033 .Output("handle: variant") 1034 .Attr("output_types: list(type) >= 1") 1035 .Attr("output_shapes: list(shape) >= 1") 1036 .Attr("optimization_configs: list(string) = []") 1037 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 1038 "output_types")) 1039 .SetShapeFn(shape_inference::ScalarShape); 1040 1041 REGISTER_OP("OptionalFromValue") 1042 .Input("components: Toutput_types") 1043 .Output("optional: variant") 1044 .Attr("Toutput_types: list(type) >= 1") 1045 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_OPTIONAL, 1046 "Toutput_types")) __anon4377504f2002(shape_inference::InferenceContext* c) 1047 .SetShapeFn([](shape_inference::InferenceContext* c) { 1048 std::vector<DataType> dtypes; 1049 TF_RETURN_IF_ERROR(c->GetAttr("Toutput_types", &dtypes)); 1050 c->set_output(0, c->Scalar()); 1051 std::vector<shape_inference::ShapeAndType> shapes_and_types; 1052 shapes_and_types.reserve(c->num_inputs()); 1053 const 
FullTypeDef& ret_types = c->ret_types(); 1054 for (int i = 0; i < c->num_inputs(); ++i) { 1055 // TODO(mdan): output_type(i) == optional is incorrect. 1056 // "Optional" is the type of the whole container, not of individual 1057 // elements. 1058 // 1059 // Why ret_types.args(0) and not args(i) -- 1060 // For example if Toutput_types is (int32, float32), then 1061 // ret_types.args[0] (i.e. the 0th output) is 1062 // Optional[Record[Tensor[int32, s1], Tensor[float32, s2]]] 1063 // set_output_handle_shapes_and_types tracks the same thing, but in 1064 // a transposed way: 1065 // {ShapeAndType(in32, s1, Optional), ShapeAndType(in32, s2, Optional)} 1066 // That should be corrected in the future (see todo above). 1067 shapes_and_types.emplace_back(c->input(i), dtypes[i], 1068 ret_types.args(0)); 1069 } 1070 c->set_output_handle_shapes_and_types(0, shapes_and_types); 1071 return OkStatus(); 1072 }); 1073 1074 REGISTER_OP("OptionalNone") 1075 .Output("optional: variant") 1076 .SetShapeFn(shape_inference::ScalarShape); 1077 1078 REGISTER_OP("OptionalHasValue") 1079 .Input("optional: variant") 1080 .Output("has_value: bool") 1081 .SetShapeFn(shape_inference::ScalarShape); 1082 1083 REGISTER_OP("OptionalGetValue") 1084 .Input("optional: variant") 1085 .Output("components: output_types") 1086 .Attr("output_types: list(type) >= 1") 1087 .Attr("output_shapes: list(shape) >= 1") 1088 .SetShapeFn(shape_inference::DatasetIteratorShape); 1089 1090 REGISTER_OP("IteratorGetNextAsOptional") 1091 .Input("iterator: resource") 1092 .Output("optional: variant") 1093 .Attr("output_types: list(type) >= 1") 1094 .Attr("output_shapes: list(shape) >= 1") 1095 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_OPTIONAL, 1096 "output_types")) 1097 .SetForwardTypeFn(full_type::MapCovariant(TFT_ITERATOR, TFT_OPTIONAL, 0)) 1098 .SetShapeFn(shape_inference::ScalarShape); 1099 1100 REGISTER_OP("ModelDataset") 1101 .Input("input_dataset: variant") 1102 .Output("handle: variant") 1103 
.Attr("algorithm: int = 0") 1104 .Attr("cpu_budget: int = 0") 1105 .Attr("ram_budget: int = 0") 1106 .Attr("output_types: list(type) >= 1") 1107 .Attr("output_shapes: list(shape) >= 1") 1108 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 1109 "output_types")) 1110 .SetShapeFn(shape_inference::ScalarShape); 1111 1112 // TODO(b/124308749): Add a stateful version of MapDefun and use it when `f` 1113 // is stateful. 1114 REGISTER_OP("MapDefun") 1115 .Input("arguments: Targuments") 1116 .Input("captured_inputs: Tcaptured") 1117 .Output("output: output_types") 1118 .Attr("Targuments: list(type) >= 1") 1119 .Attr("Tcaptured: list(type) >= 0 = []") 1120 .Attr("output_types: list(type) >= 1") 1121 .Attr("output_shapes: list(shape) >= 1") 1122 .Attr("f: func") 1123 .Attr("max_intra_op_parallelism: int = 1") __anon4377504f2102(shape_inference::InferenceContext* c) 1124 .SetShapeFn([](shape_inference::InferenceContext* c) { 1125 std::vector<PartialTensorShape> output_shapes; 1126 TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); 1127 DataTypeVector t_args; 1128 TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args)); 1129 if (output_shapes.size() != c->num_outputs()) { 1130 return errors::InvalidArgument( 1131 "`output_shapes` must be the same length as `output_types` (", 1132 output_shapes.size(), " vs. ", c->num_outputs(), ")"); 1133 } 1134 1135 int64_t dim_zero = -1; 1136 for (size_t i = 0; i < t_args.size(); ++i) { 1137 if (c->Rank(c->input(i)) == 0) { 1138 return errors::InvalidArgument( 1139 "Arguments must have rank at least 1. 
Input ", i, 1140 " has rank of 0."); 1141 } 1142 auto dim_handle = c->Dim(c->input(i), 0); 1143 if (c->ValueKnown(dim_handle)) { 1144 if (dim_zero == -1) { 1145 dim_zero = c->Value(dim_handle); 1146 } else if (c->Value(dim_handle) != dim_zero) { 1147 return errors::InvalidArgument( 1148 "Arguments must have the same dimension 0."); 1149 } 1150 } 1151 } 1152 1153 for (size_t i = 0; i < output_shapes.size(); ++i) { 1154 PartialTensorShape s({}); 1155 s = s.Concatenate(dim_zero); 1156 s = s.Concatenate(output_shapes[i]); 1157 shape_inference::ShapeHandle output_shape_handle; 1158 1159 TF_RETURN_IF_ERROR( 1160 c->MakeShapeFromPartialTensorShape(s, &output_shape_handle)); 1161 c->set_output(static_cast<int>(i), output_shape_handle); 1162 } 1163 return OkStatus(); 1164 }); 1165 1166 REGISTER_OP("WrapDatasetVariant") 1167 .Input("input_handle: variant") 1168 .Output("output_handle: variant") 1169 .SetShapeFn(shape_inference::ScalarShape); 1170 1171 REGISTER_OP("UnwrapDatasetVariant") 1172 .Input("input_handle: variant") 1173 .Output("output_handle: variant") 1174 .SetShapeFn(shape_inference::ScalarShape); 1175 1176 REGISTER_OP("AnonymousMultiDeviceIterator") 1177 .Output("handle: resource") 1178 .Output("deleter: variant") 1179 .Attr("devices: list(string) >= 1") 1180 .Attr("output_types: list(type) >= 1") 1181 .Attr("output_shapes: list(shape) >= 1") __anon4377504f2202(shape_inference::InferenceContext* c) 1182 .SetShapeFn([](shape_inference::InferenceContext* c) { 1183 c->set_output(0, c->Scalar()); 1184 c->set_output(1, c->Scalar()); 1185 return OkStatus(); 1186 }); 1187 1188 REGISTER_OP("AnonymousMultiDeviceIteratorV3") 1189 .Output("handle: resource") 1190 .Attr("devices: list(string) >= 1") 1191 .Attr("output_types: list(type) >= 1") 1192 .Attr("output_shapes: list(shape) >= 1") __anon4377504f2302(shape_inference::InferenceContext* c) 1193 .SetShapeFn([](shape_inference::InferenceContext* c) { 1194 c->set_output(0, c->Scalar()); 1195 return OkStatus(); 1196 }); 1197 
1198 REGISTER_OP("MultiDeviceIterator") 1199 .Output("handle: resource") 1200 .Attr("devices: list(string) >= 1") 1201 .Attr("shared_name: string") 1202 .Attr("container: string") 1203 .Attr("output_types: list(type) >= 1") 1204 .Attr("output_shapes: list(shape) >= 1") 1205 .SetShapeFn(shape_inference::ScalarShape); 1206 1207 REGISTER_OP("MultiDeviceIteratorInit") 1208 .Input("dataset: variant") 1209 .Input("multi_device_iterator: resource") 1210 .Input("max_buffer_size: int64") 1211 .Output("incarnation_id: int64") 1212 .SetShapeFn(shape_inference::ScalarShape); 1213 1214 REGISTER_OP("MultiDeviceIteratorGetNextFromShard") 1215 .Input("multi_device_iterator: resource") 1216 .Input("shard_num: int32") 1217 .Input("incarnation_id: int64") 1218 .Output("components: output_types") 1219 .Attr("output_types: list(type) >= 1") 1220 .Attr("output_shapes: list(shape) >= 1") 1221 .SetShapeFn(shape_inference::DatasetIteratorShape); 1222 1223 REGISTER_OP("MultiDeviceIteratorToStringHandle") 1224 .Input("multi_device_iterator: resource") 1225 .Output("string_handle: string") 1226 .SetForwardTypeFn(full_type::Encode(TFT_STRING, 0)) 1227 .SetShapeFn(shape_inference::ScalarShape); 1228 1229 REGISTER_OP("MultiDeviceIteratorFromStringHandle") 1230 .Input("string_handle: string") 1231 .Output("multi_device_iterator: resource") 1232 .Attr("output_types: list(type) >= 0 = []") 1233 .Attr("output_shapes: list(shape) >= 0 = []") 1234 .SetForwardTypeFn(full_type::Decode(TFT_STRING, 0)) 1235 .SetShapeFn(shape_inference::ScalarShape); 1236 1237 REGISTER_OP("OptionsDataset") 1238 .Input("input_dataset: variant") 1239 .Output("handle: variant") 1240 .Attr("serialized_options: string") 1241 .Attr("output_types: list(type) >= 1") 1242 .Attr("output_shapes: list(shape) >= 1") 1243 .Attr("metadata: string = ''") 1244 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 1245 "output_types")) 1246 .SetShapeFn(shape_inference::ScalarShape); 1247 1248 REGISTER_OP("GetOptions") 1249 
.Input("input_dataset: variant") 1250 .Output("serialized_options: string") 1251 .SetShapeFn(shape_inference::ScalarShape); 1252 1253 REGISTER_OP("FinalizeDataset") 1254 .Input("input_dataset: variant") 1255 .Output("handle: variant") 1256 .Attr("has_captured_ref: bool = false") 1257 .Attr("output_types: list(type) >= 1") 1258 .Attr("output_shapes: list(shape) >= 1") 1259 .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET, 1260 "output_types")) 1261 .SetShapeFn(shape_inference::ScalarShape); 1262 1263 } // namespace tensorflow 1264