xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ops/dataset_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/common_shape_fns.h"
16 #include "tensorflow/core/framework/full_type.pb.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/op_def_builder.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20 
21 namespace tensorflow {
22 
23 // --------------------------------------------------------------------------
24 
// The ops in this section can be composed to define an input
// pipeline. Each op produces a DT_VARIANT tensor that represents
// a DAG of "dataset" objects. A "dataset" object can be converted
// to a stateful "iterator" by passing the "dataset" to the
// "MakeIterator" op.
30 //
31 // TODO(b/123753214): DT_VARIANT tensors that represent "dataset" objects are
32 // not presently serializable. To avoid issues with graph optimizations, such
33 // as constant folding, CSE, or DCE, ensure that any "source dataset" ops
34 // (i.e. ops that output a dataset and do not take one as input) are
35 // marked as "do not optimize".
36 
// Source dataset op (takes no dataset input): builds a dataset from the
// `components` tensors, whose dtypes are listed in `Toutput_types`. The
// output `handle` is a scalar variant. Marked "do not optimize" because
// source dataset variants are not serializable (see b/123753214).
//
// TODO(mrry): Validate that `components` have shapes compatible with
// `output_shapes`.
REGISTER_OP("TensorDataset")
    .Input("components: Toutput_types")
    .Output("handle: variant")
    .Attr("Toutput_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "Toutput_types"))
    .SetShapeFn(shape_inference::ScalarShape);
49 
// Source dataset op over the dim-0 slices of `components`. The forward type
// function applies full_type::MultiaryUnstack/UnstackTensor to the input
// component types. The output `handle` is a scalar variant.
// NOTE(review): the meaning of `is_files` and `replicate_on_split` is defined
// by the kernel, not here; confirm there before relying on it.
//
// TODO(mrry): Validate that the dim-0 slices of `components` have shapes
// compatible with `output_shapes`.
REGISTER_OP("TensorSliceDataset")
    .Input("components: Toutput_types")
    .Output("handle: variant")
    .Attr("Toutput_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("is_files: bool = false")
    .Attr("metadata: string = ''")
    .Attr("replicate_on_split: bool = false")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "Toutput_types"))
    .SetForwardTypeFn(full_type::MultiaryUnstack(TFT_DATASET,
                                                 full_type::UnstackTensor))
    .SetShapeFn(shape_inference::ScalarShape);
66 
67 REGISTER_OP("SparseTensorSliceDataset")
68     .Input("indices: int64")
69     .Input("values: Tvalues")
70     .Input("dense_shape: int64")
71     .Output("handle: variant")
72     .Attr("Tvalues: type")
73     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
74     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET, "Tvalues"))
75     .SetShapeFn(shape_inference::ScalarShape);
76 
// Source dataset op driven by three functions: `init_func`, `next_func` and
// `finalize_func`. Each function has its own list of extra captured inputs
// (`*_func_other_args`), typed by the matching `T*_func_args` attr. The output
// `handle` is a scalar variant.
REGISTER_OP("GeneratorDataset")
    .Input("init_func_other_args: Tinit_func_args")
    .Input("next_func_other_args: Tnext_func_args")
    .Input("finalize_func_other_args: Tfinalize_func_args")
    .Output("handle: variant")
    .Attr("init_func: func")
    .Attr("next_func: func")
    .Attr("finalize_func: func")
    .Attr("Tinit_func_args: list(type) >= 0")
    .Attr("Tnext_func_args: list(type) >= 0")
    .Attr("Tfinalize_func_args: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
95 
// Dataset op taking `N` (>= 1) input dataset handles in `input_datasets`.
// The output `handle` is a scalar variant.
REGISTER_OP("ZipDataset")
    .Input("input_datasets: N * variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("N: int >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
106 
// Dataset op taking two dataset handles (`input_dataset`, `another_dataset`).
// The output `handle` is a scalar variant.
REGISTER_OP("ConcatenateDataset")
    .Input("input_dataset: variant")
    .Input("another_dataset: variant")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
117 
118 REGISTER_OP("RepeatDataset")
119     .Input("input_dataset: variant")
120     .Input("count: int64")
121     .Output("handle: variant")
122     .Attr("output_types: list(type) >= 1")
123     .Attr("output_shapes: list(shape) >= 1")
124     .Attr("metadata: string = ''")
125     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
126                                                            "output_types"))
__anon4377504f0102(shape_inference::InferenceContext* c) 127     .SetShapeFn([](shape_inference::InferenceContext* c) {
128       shape_inference::ShapeHandle count_shape;
129       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
130       return shape_inference::ScalarShape(c);
131     });
132 
133 REGISTER_OP("TakeDataset")
134     .Input("input_dataset: variant")
135     .Input("count: int64")
136     .Output("handle: variant")
137     .Attr("output_types: list(type) >= 1")
138     .Attr("output_shapes: list(shape) >= 1")
139     .Attr("metadata: string = ''")
140     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
141                                                            "output_types"))
__anon4377504f0202(shape_inference::InferenceContext* c) 142     .SetShapeFn([](shape_inference::InferenceContext* c) {
143       shape_inference::ShapeHandle count_shape;
144       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
145       return shape_inference::ScalarShape(c);
146     });
147 
148 REGISTER_OP("SkipDataset")
149     .Input("input_dataset: variant")
150     .Input("count: int64")
151     .Output("handle: variant")
152     .Attr("output_types: list(type) >= 1")
153     .Attr("output_shapes: list(shape) >= 1")
154     .Attr("metadata: string = ''")
155     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
156                                                            "output_types"))
__anon4377504f0302(shape_inference::InferenceContext* c) 157     .SetShapeFn([](shape_inference::InferenceContext* c) {
158       shape_inference::ShapeHandle count_shape;
159       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &count_shape));
160       return shape_inference::ScalarShape(c);
161     });
162 
// Dataset op applying function `f`. `other_arguments` (typed by `Targuments`)
// are extra inputs for `f`. The output `handle` is a scalar variant.
REGISTER_OP("MapDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    .Attr("preserve_cardinality: bool = false")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
177 
178 REGISTER_OP("ParallelMapDataset")
179     .Input("input_dataset: variant")
180     .Input("other_arguments: Targuments")
181     .Input("num_parallel_calls: int32")
182     .Output("handle: variant")
183     .Attr("f: func")
184     .Attr("Targuments: list(type) >= 0")
185     .Attr("output_types: list(type) >= 1")
186     .Attr("output_shapes: list(shape) >= 1")
187     .Attr("use_inter_op_parallelism: bool = true")
188     .Attr("sloppy: bool = false")
189     .Attr("preserve_cardinality: bool = false")
190     .Attr("metadata: string = ''")
191     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
192                                                            "output_types"))
193     .SetShapeFn(shape_inference::ScalarShape);
194 
195 REGISTER_OP("ParallelMapDatasetV2")
196     .Input("input_dataset: variant")
197     .Input("other_arguments: Targuments")
198     .Input("num_parallel_calls: int64")
199     .Output("handle: variant")
200     .Attr("f: func")
201     .Attr("Targuments: list(type) >= 0")
202     .Attr("output_types: list(type) >= 1")
203     .Attr("output_shapes: list(shape) >= 1")
204     .Attr("use_inter_op_parallelism: bool = true")
205     // "true", "false", or "default".
206     .Attr("deterministic: string = 'default'")
207     .Attr("preserve_cardinality: bool = false")
208     .Attr("metadata: string = ''")
209     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
210                                                            "output_types"))
211     .SetShapeFn(shape_inference::ScalarShape);
212 
213 REGISTER_OP("PrefetchDataset")
214     .Input("input_dataset: variant")
215     .Input("buffer_size: int64")
216     .Output("handle: variant")
217     .Attr("output_types: list(type) >= 1")
218     .Attr("output_shapes: list(shape) >= 1")
219     .Attr("slack_period: int = 0")
220     .Attr("legacy_autotune: bool = true")
221     .Attr("buffer_size_min: int = 0")
222     .Attr("metadata: string = ''")
223     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
224                                                            "output_types"))
__anon4377504f0402(shape_inference::InferenceContext* c) 225     .SetShapeFn([](shape_inference::InferenceContext* c) {
226       shape_inference::ShapeHandle unused;
227       // buffer_size should be a scalar.
228       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
229       return shape_inference::ScalarShape(c);
230     });
231 
// Dataset op applying function `f`. `other_arguments` (typed by `Targuments`)
// are extra inputs for `f`. The output `handle` is a scalar variant.
REGISTER_OP("FlatMapDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("f: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
244 
245 REGISTER_OP("InterleaveDataset")
246     .Input("input_dataset: variant")
247     .Input("other_arguments: Targuments")
248     .Input("cycle_length: int64")
249     .Input("block_length: int64")
250     .Output("handle: variant")
251     .Attr("f: func")
252     .Attr("Targuments: list(type) >= 0")
253     .Attr("output_types: list(type) >= 1")
254     .Attr("output_shapes: list(shape) >= 1")
255     .Attr("metadata: string = ''")
256     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
257                                                            "output_types"))
258     .SetShapeFn(shape_inference::ScalarShape);
259 
260 REGISTER_OP("ParallelInterleaveDatasetV2")
261     .Input("input_dataset: variant")
262     .Input("other_arguments: Targuments")
263     .Input("cycle_length: int64")
264     .Input("block_length: int64")
265     .Input("num_parallel_calls: int64")
266     .Output("handle: variant")
267     .Attr("f: func")
268     .Attr("Targuments: list(type) >= 0")
269     .Attr("output_types: list(type) >= 1")
270     .Attr("output_shapes: list(shape) >= 1")
271     .Attr("sloppy: bool = false")
272     .Attr("metadata: string = ''")
273     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
274                                                            "output_types"))
275     .SetShapeFn(shape_inference::ScalarShape);
276 
277 REGISTER_OP("ParallelInterleaveDatasetV3")
278     .Input("input_dataset: variant")
279     .Input("other_arguments: Targuments")
280     .Input("cycle_length: int64")
281     .Input("block_length: int64")
282     .Input("num_parallel_calls: int64")
283     .Output("handle: variant")
284     .Attr("f: func")
285     // "true", "false", or "default".
286     .Attr("deterministic: string = 'default'")
287     .Attr("Targuments: list(type) >= 0")
288     .Attr("output_types: list(type) >= 1")
289     .Attr("output_shapes: list(shape) >= 1")
290     .Attr("metadata: string = ''")
291     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
292                                                            "output_types"))
293     .SetShapeFn(shape_inference::ScalarShape);
294 
295 // Like V3, but adds buffer_output_elements and prefetch_input_elements.
296 REGISTER_OP("ParallelInterleaveDatasetV4")
297     .Input("input_dataset: variant")
298     .Input("other_arguments: Targuments")
299     .Input("cycle_length: int64")
300     .Input("block_length: int64")
301     .Input("buffer_output_elements: int64")
302     .Input("prefetch_input_elements: int64")
303     .Input("num_parallel_calls: int64")
304     .Output("handle: variant")
305     .Attr("f: func")
306     // "true", "false", or "default".
307     .Attr("deterministic: string = 'default'")
308     .Attr("Targuments: list(type) >= 0")
309     .Attr("output_types: list(type) >= 1")
310     .Attr("output_shapes: list(shape) >= 1")
311     .Attr("metadata: string = ''")
312     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
313                                                            "output_types"))
314     .SetShapeFn(shape_inference::ScalarShape);
315 
// Dataset op parameterized by the function `predicate`. `other_arguments`
// (typed by `Targuments`) are extra inputs for `predicate`. The output
// `handle` is a scalar variant.
REGISTER_OP("FilterDataset")
    .Input("input_dataset: variant")
    .Input("other_arguments: Targuments")
    .Output("handle: variant")
    .Attr("predicate: func")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
328 
329 REGISTER_OP("ParallelFilterDataset")
330     .Input("input_dataset: variant")
331     .Input("other_arguments: Targuments")
332     .Input("num_parallel_calls: int64")
333     .Output("handle: variant")
334     .Attr("predicate: func")
335     // "true", "false", or "default".
336     .Attr("deterministic: string = 'default'")
337     .Attr("Targuments: list(type) >= 0")
338     .Attr("output_types: list(type) >= 1")
339     .Attr("output_shapes: list(shape) >= 1")
340     .Attr("metadata: string = ''")
341     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
342                                                            "output_types"))
343     .SetShapeFn(shape_inference::ScalarShape);
344 
// This op is no longer supported. The registration is kept; its only output
// is named `output` (not `handle` as in the other dataset ops here) and is a
// scalar variant.
REGISTER_OP("FilterByLastComponentDataset")
    .Input("input_dataset: variant")
    .Output("output: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
354 
355 REGISTER_OP("WindowDataset")
356     .Input("input_dataset: variant")
357     .Input("size: int64")
358     .Input("shift: int64")
359     .Input("stride: int64")
360     .Input("drop_remainder: bool")
361     .Output("handle: variant")
362     .Attr("output_types: list(type) >= 1")
363     .Attr("output_shapes: list(shape) >= 1")
364     .Attr("metadata: string = ''")
365     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
366                                                            "output_types"))
__anon4377504f0502(shape_inference::InferenceContext* c) 367     .SetShapeFn([](shape_inference::InferenceContext* c) {
368       shape_inference::ShapeHandle unused;
369       // size, shift, stride, and drop_remainder should be scalars.
370       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
371       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
372       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
373       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
374       return shape_inference::ScalarShape(c);
375     });
376 
// Op producing a scalar variant `handle` from the tensors in `inputs`. The
// dtypes of `inputs` (`Tinputs`) are declared separately from the element
// `output_types`.
REGISTER_OP("WindowOp")
    .Input("inputs: Tinputs")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("Tinputs: list(type) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
386 
387 REGISTER_OP("BatchDataset")
388     .Input("input_dataset: variant")
389     .Input("batch_size: int64")
390     .Output("handle: variant")
391     .Attr("output_types: list(type) >= 1")
392     .Attr("output_shapes: list(shape) >= 1")
393     .Attr("metadata: string = ''")
394     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
395                                                            "output_types"))
__anon4377504f0602(shape_inference::InferenceContext* c) 396     .SetShapeFn([](shape_inference::InferenceContext* c) {
397       shape_inference::ShapeHandle unused;
398       // batch_size should be a scalar.
399       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
400       return shape_inference::ScalarShape(c);
401     });
402 
403 REGISTER_OP("BatchDatasetV2")
404     .Input("input_dataset: variant")
405     .Input("batch_size: int64")
406     .Input("drop_remainder: bool")
407     .Output("handle: variant")
408     .Attr("parallel_copy: bool = false")
409     .Attr("output_types: list(type) >= 1")
410     .Attr("output_shapes: list(shape) >= 1")
411     .Attr("metadata: string = ''")
412     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
413                                                            "output_types"))
414     .SetForwardTypeFn(full_type::ContainerMap(TFT_DATASET, /*input_idx=*/0,
415                                               full_type::BatchTensor))
__anon4377504f0702(shape_inference::InferenceContext* c) 416     .SetShapeFn([](shape_inference::InferenceContext* c) {
417       shape_inference::ShapeHandle unused;
418       // batch_size should be a scalar.
419       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
420       // drop_remainder should be a scalar.
421       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
422       return shape_inference::ScalarShape(c);
423     });
424 
425 REGISTER_OP("ParallelBatchDataset")
426     .Input("input_dataset: variant")
427     .Input("batch_size: int64")
428     .Input("num_parallel_calls: int64")
429     .Input("drop_remainder: bool")
430     .Output("handle: variant")
431     .Attr("parallel_copy: bool = false")
432     .Attr("output_types: list(type) >= 1")
433     .Attr("output_shapes: list(shape) >= 1")
434     // "true", "false", or "default".
435     .Attr("deterministic: string = 'default'")
436     .Attr("metadata: string = ''")
437     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
438                                                            "output_types"))
__anon4377504f0802(shape_inference::InferenceContext* c) 439     .SetShapeFn([](shape_inference::InferenceContext* c) {
440       shape_inference::ShapeHandle unused;
441       // batch_size should be a scalar.
442       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
443       // num_parallel_calls should be a scalar.
444       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
445       // drop_remainder should be a scalar.
446       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
447       return shape_inference::ScalarShape(c);
448     });
449 
450 REGISTER_OP("ShardDataset")
451     .Input("input_dataset: variant")
452     .Input("num_shards: int64")
453     .Input("index: int64")
454     .Output("handle: variant")
455     .Attr("require_non_empty: bool = false")
456     .Attr("output_types: list(type) >= 1")
457     .Attr("output_shapes: list(shape) >= 1")
458     .Attr("metadata: string = ''")
459     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
460                                                            "output_types"))
__anon4377504f0902(shape_inference::InferenceContext* c) 461     .SetShapeFn([](shape_inference::InferenceContext* c) {
462       shape_inference::ShapeHandle unused;
463       // num_shards should be a scalar.
464       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
465       // index should be a scalar.
466       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
467       return shape_inference::ScalarShape(c);
468     });
469 
470 // TODO(mrry): Validate that `padded_shapes` are all vectors, the lengths of
471 // `output_types` and `output_shapes` are `N` the `output_shapes` are (as far as
472 // possible to tell statically) compatible with `padded_shapes`, and that
473 // `padding_values` are all scalars.
474 REGISTER_OP("PaddedBatchDataset")
475     .Input("input_dataset: variant")
476     .Input("batch_size: int64")
477     .Input("padded_shapes: N * int64")
478     .Input("padding_values: Toutput_types")
479     .Output("handle: variant")
480     .Attr("Toutput_types: list(type) >= 1")
481     .Attr("output_shapes: list(shape) >= 1")
482     .Attr("N: int >= 1")
483     .Attr("metadata: string = ''")
484     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
485                                                            "Toutput_types"))
__anon4377504f0a02(shape_inference::InferenceContext* c) 486     .SetShapeFn([](shape_inference::InferenceContext* c) {
487       shape_inference::ShapeHandle unused;
488       // batch_size should be a scalar.
489       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
490       return shape_inference::ScalarShape(c);
491     });
492 
493 REGISTER_OP("PaddedBatchDatasetV2")
494     .Input("input_dataset: variant")
495     .Input("batch_size: int64")
496     .Input("padded_shapes: N * int64")
497     .Input("padding_values: Toutput_types")
498     .Input("drop_remainder: bool")
499     .Output("handle: variant")
500     .Attr("parallel_copy: bool = false")
501     .Attr("Toutput_types: list(type) >= 1")
502     .Attr("output_shapes: list(shape) >= 1")
503     .Attr("N: int >= 1")
504     .Attr("metadata: string = ''")
505     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
506                                                            "Toutput_types"))
__anon4377504f0b02(shape_inference::InferenceContext* c) 507     .SetShapeFn([](shape_inference::InferenceContext* c) {
508       shape_inference::ShapeHandle unused;
509       // batch_size should be a scalar.
510       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
511       // drop_remainder should be a scalar.
512       TF_RETURN_IF_ERROR(
513           c->WithRank(c->input(c->num_inputs() - 1), 0, &unused));
514       return shape_inference::ScalarShape(c);
515     });
516 
517 REGISTER_OP("RangeDataset")
518     .Input("start: int64")
519     .Input("stop: int64")
520     .Input("step: int64")
521     .Output("handle: variant")
522     .Attr("output_types: list(type) >= 1")
523     .Attr("output_shapes: list(shape) >= 1")
524     .Attr("metadata: string = ''")
525     .Attr("replicate_on_split: bool = false")
526     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
527     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
528                                                            "output_types"))
__anon4377504f0c02(shape_inference::InferenceContext* c) 529     .SetShapeFn([](shape_inference::InferenceContext* c) {
530       shape_inference::ShapeHandle unused;
531       // start, stop, and step should be scalars.
532       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
533       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
534       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
535       return shape_inference::ScalarShape(c);
536     });
537 
538 REGISTER_OP("RewriteDataset")
539     .Input("input_dataset: variant")
540     .Input("rewrite_name: string")
541     .Output("handle: variant")
542     .Attr("output_types: list(type) >= 1")
543     .Attr("output_shapes: list(shape) >= 1")
544     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
545                                                            "output_types"))
546     .SetShapeFn(shape_inference::ScalarShape);
547 
548 REGISTER_OP("AnonymousSeedGenerator")
549     .Input("seed: int64")
550     .Input("seed2: int64")
551     .Input("reshuffle: bool")
552     .Output("handle: resource")
553     .Output("deleter: variant")
__anon4377504f0d02(shape_inference::InferenceContext* c) 554     .SetShapeFn([](shape_inference::InferenceContext* c) {
555       c->set_output(0, c->Scalar());
556       c->set_output(1, c->Scalar());
557       return OkStatus();
558     });
559 
// Returns an int64 scalar `cardinality` for the dataset `input_dataset`.
REGISTER_OP("DatasetCardinality")
    .Input("input_dataset: variant")
    .Output("cardinality: int64")
    .SetShapeFn(shape_inference::ScalarShape);
564 
// Takes the (`handle`, `deleter`) pair produced by AnonymousSeedGenerator.
// Has no outputs.
REGISTER_OP("DeleteSeedGenerator")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);
569 
570 // Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator.
571 REGISTER_OP("AnonymousRandomSeedGenerator")
572     .Input("seed: int64")
573     .Input("seed2: int64")
574     .Output("handle: resource")
575     .Output("deleter: variant")
__anon4377504f0e02(shape_inference::InferenceContext* c) 576     .SetShapeFn([](shape_inference::InferenceContext* c) {
577       c->set_output(0, c->Scalar());
578       c->set_output(1, c->Scalar());
579       return OkStatus();
580     });
581 
// Deprecated in favor of AnonymousSeedGenerator/DeleteSeedGenerator.
// Takes the (`handle`, `deleter`) pair from AnonymousRandomSeedGenerator and
// has no outputs.
REGISTER_OP("DeleteRandomSeedGenerator")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);
587 
588 REGISTER_OP("DummySeedGenerator")
589     .Output("handle: resource")
__anon4377504f0f02(shape_inference::InferenceContext* c) 590     .SetShapeFn([](shape_inference::InferenceContext* c) {
591       c->set_output(0, c->Scalar());
592       return OkStatus();
593     });
594 
595 REGISTER_OP("ShuffleDataset")
596     .Input("input_dataset: variant")
597     .Input("buffer_size: int64")
598     .Input("seed: int64")
599     .Input("seed2: int64")
600     .Output("handle: variant")
601     .Attr("reshuffle_each_iteration: bool = true")
602     .Attr("output_types: list(type) >= 1")
603     .Attr("output_shapes: list(shape) >= 1")
604     .Attr("metadata: string = ''")
605     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
606                                                            "output_types"))
__anon4377504f1002(shape_inference::InferenceContext* c) 607     .SetShapeFn([](shape_inference::InferenceContext* c) {
608       shape_inference::ShapeHandle unused;
609       // buffer_size, seed, and seed2 should be scalars.
610       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
611       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
612       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
613       return shape_inference::ScalarShape(c);
614     });
615 
616 REGISTER_OP("ShuffleDatasetV2")
617     .Input("input_dataset: variant")
618     .Input("buffer_size: int64")
619     .Input("seed_generator: resource")
620     .Output("handle: variant")
621     .Attr("output_types: list(type) >= 1")
622     .Attr("output_shapes: list(shape) >= 1")
623     .Attr("metadata: string = ''")
624     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
625                                                            "output_types"))
__anon4377504f1102(shape_inference::InferenceContext* c) 626     .SetShapeFn([](shape_inference::InferenceContext* c) {
627       shape_inference::ShapeHandle unused;
628       // buffer_size and seed_generator should be scalars.
629       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
630       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
631       return shape_inference::ScalarShape(c);
632     });
633 
634 REGISTER_OP("ShuffleDatasetV3")
635     .Input("input_dataset: variant")
636     .Input("buffer_size: int64")
637     .Input("seed: int64")
638     .Input("seed2: int64")
639     .Input("seed_generator: resource")
640     .Output("handle: variant")
641     .Attr("reshuffle_each_iteration: bool = true")
642     .Attr("output_types: list(type) >= 1")
643     .Attr("output_shapes: list(shape) >= 1")
644     .Attr("metadata: string = ''")
645     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
646                                                            "output_types"))
__anon4377504f1202(shape_inference::InferenceContext* c) 647     .SetShapeFn([](shape_inference::InferenceContext* c) {
648       shape_inference::ShapeHandle unused;
649       // buffer_size, seed, seed2, and seed_generator should be scalars.
650       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
651       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
652       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
653       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
654       return shape_inference::ScalarShape(c);
655     });
656 
657 REGISTER_OP("ShuffleAndRepeatDataset")
658     .Input("input_dataset: variant")
659     .Input("buffer_size: int64")
660     .Input("seed: int64")
661     .Input("seed2: int64")
662     .Input("count: int64")
663     .Output("handle: variant")
664     .Attr("output_types: list(type) >= 1")
665     .Attr("output_shapes: list(shape) >= 1")
666     .Attr("reshuffle_each_iteration: bool = true")
667     .Attr("metadata: string = ''")
668     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
669                                                            "output_types"))
__anon4377504f1302(shape_inference::InferenceContext* c) 670     .SetShapeFn([](shape_inference::InferenceContext* c) {
671       shape_inference::ShapeHandle unused;
672       // buffer_size, seed, seed2, and count should be scalars.
673       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
674       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
675       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
676       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
677       return shape_inference::ScalarShape(c);
678     });
679 
680 REGISTER_OP("ShuffleAndRepeatDatasetV2")
681     .Input("input_dataset: variant")
682     .Input("buffer_size: int64")
683     .Input("seed: int64")
684     .Input("seed2: int64")
685     .Input("count: int64")
686     .Input("seed_generator: resource")
687     .Output("handle: variant")
688     .Attr("reshuffle_each_iteration: bool = true")
689     .Attr("output_types: list(type) >= 1")
690     .Attr("output_shapes: list(shape) >= 1")
691     .Attr("metadata: string = ''")
692     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
693                                                            "output_types"))
__anon4377504f1402(shape_inference::InferenceContext* c) 694     .SetShapeFn([](shape_inference::InferenceContext* c) {
695       shape_inference::ShapeHandle unused;
696       // buffer_size, seed, seed2, count, and seed_generator should be scalars.
697       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
698       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
699       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
700       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
701       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
702       return shape_inference::ScalarShape(c);
703     });
704 
// Creates an anonymous memory-cache resource. Both outputs, the resource
// `handle` and the variant `deleter`, are scalars.
// NOTE(review): `deleter` presumably ties the resource lifetime to
// DeleteMemoryCache, which takes the same (handle, deleter) pair — confirm
// against the kernel.
REGISTER_OP("AnonymousMemoryCache")
    .Output("handle: resource")
    .Output("deleter: variant")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return OkStatus();
    });

// Takes a memory-cache `handle` and its `deleter`; produces no outputs.
// Presumably releases the cache, as the name suggests.
REGISTER_OP("DeleteMemoryCache")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);

// Produces a scalar memory-cache resource `handle` with no deleter.
// NOTE(review): the name suggests a placeholder for callers that need a
// `cache` resource but no real cache — confirm.
REGISTER_OP("DummyMemoryCache")
    .Output("handle: resource")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      return OkStatus();
    });
725 
// Dataset op that (per its name) caches the elements of `input_dataset`.
// `filename` must be a scalar string. NOTE(review): presumably selects the
// cache location — confirm how an empty filename is handled by the kernel.
// The output `handle` is a scalar variant.
REGISTER_OP("CacheDataset")
    .Input("input_dataset: variant")
    .Input("filename: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    // TODO(mdan): Should these use type inference instead?
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // filename should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return shape_inference::ScalarShape(c);
    });

// CacheDataset with an additional scalar `cache` resource input.
// NOTE(review): `cache` is presumably a handle from AnonymousMemoryCache or
// DummyMemoryCache — confirm.
REGISTER_OP("CacheDatasetV2")
    .Input("input_dataset: variant")
    .Input("filename: string")
    .Input("cache: resource")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // filename should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // cache (a resource handle) should be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      return shape_inference::ScalarShape(c);
    });
761 
// Source dataset over the text in `filenames`. Its elements are strings
// (UnaryTensorContainer(TFT_DATASET, TFT_STRING)).
// NOTE(review): presumably one element per line, as the name suggests —
// confirm. Like the other source datasets it is marked do-not-optimize; see
// the TODO(b/123753214) note at the top of this file.
REGISTER_OP("TextLineDataset")
    .Input("filenames: string")
    .Input("compression_type: string")
    .Input("buffer_size: int64")
    .Attr("metadata: string = ''")
    .Output("handle: variant")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
                                                        TFT_STRING))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `filenames` must be a scalar or a vector.
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
      // `compression_type` must be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // `buffer_size` must be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      // The result is a scalar variant dataset handle.
      return shape_inference::ScalarShape(c);
    });
781 
782 REGISTER_OP("FixedLengthRecordDataset")
783     .Input("filenames: string")
784     .Input("header_bytes: int64")
785     .Input("record_bytes: int64")
786     .Input("footer_bytes: int64")
787     .Input("buffer_size: int64")
788     .Attr("metadata: string = ''")
789     .Output("handle: variant")
790     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
791     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
792                                                         TFT_STRING))
__anon4377504f1a02(shape_inference::InferenceContext* c) 793     .SetShapeFn([](shape_inference::InferenceContext* c) {
794       shape_inference::ShapeHandle unused;
795       // `filenames` must be a scalar or a vector.
796       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
797       // header_bytes, record_bytes, footer_bytes, buffer_size should be
798       // scalars.
799       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
800       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
801       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
802       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
803       return shape_inference::ScalarShape(c);
804     });
805 
806 REGISTER_OP("FixedLengthRecordDatasetV2")
807     .Input("filenames: string")
808     .Input("header_bytes: int64")
809     .Input("record_bytes: int64")
810     .Input("footer_bytes: int64")
811     .Input("buffer_size: int64")
812     .Input("compression_type: string")
813     .Attr("metadata: string = ''")
814     .Output("handle: variant")
815     .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
816     .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
817                                                         TFT_STRING))
__anon4377504f1b02(shape_inference::InferenceContext* c) 818     .SetShapeFn([](shape_inference::InferenceContext* c) {
819       shape_inference::ShapeHandle unused;
820       // `filenames` must be a scalar or a vector.
821       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
822       // header_bytes, record_bytes, footer_bytes, buffer_size should be
823       // scalars.
824       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
825       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
826       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
827       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
828       return shape_inference::ScalarShape(c);
829     });
830 
// Source dataset over the records of `filenames`. Its elements are strings
// (UnaryTensorContainer(TFT_DATASET, TFT_STRING)).
// NOTE(review): presumably reads TFRecord-format files, as the name
// suggests — confirm. It is marked do-not-optimize like the other source
// datasets; see the TODO(b/123753214) note at the top of this file.
REGISTER_OP("TFRecordDataset")
    .Input("filenames: string")
    .Input("compression_type: string")
    .Input("buffer_size: int64")
    .Attr("metadata: string = ''")
    .Output("handle: variant")
    .SetDoNotOptimize()  // TODO(b/123753214): See comment in dataset_ops.cc.
    .SetTypeConstructor(full_type::UnaryTensorContainer(TFT_DATASET,
                                                        TFT_STRING))
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      shape_inference::ShapeHandle unused;
      // `filenames` must be a scalar or a vector.
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
      // `compression_type` must be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      // `buffer_size` must be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      // The result is a scalar variant dataset handle.
      return shape_inference::ScalarShape(c);
    });
850 
// Iterator resource op with `shared_name` and `container` attrs.
// `output_types` / `output_shapes` describe the element structure; the
// `handle` output is a scalar resource.
REGISTER_OP("Iterator")
    .Output("handle: resource")
    .Attr("shared_name: string")
    .Attr("container: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

// Registered with exactly the same outputs, attrs and shape function as
// Iterator.
REGISTER_OP("IteratorV2")
    .Output("handle: resource")
    .Attr("shared_name: string")
    .Attr("container: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

// Iterator resource without `shared_name` / `container` attrs; the
// `handle` output is a scalar resource.
REGISTER_OP("AnonymousIterator")
    .Output("handle: resource")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

// Like AnonymousIterator, but also outputs a scalar variant `deleter`.
// NOTE(review): presumably the (handle, deleter) pair is later passed to
// DeleteIterator, which takes exactly those two inputs — confirm.
REGISTER_OP("AnonymousIteratorV2")
    .Output("handle: resource")
    .Output("deleter: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return OkStatus();
    });

// Like AnonymousIteratorV2 but without the `deleter` output; `handle` is a
// scalar resource.
REGISTER_OP("AnonymousIteratorV3")
    .Output("handle: resource")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      return OkStatus();
    });
892 
// Takes an iterator `handle` and its variant `deleter` (the pair that
// AnonymousIteratorV2 outputs); produces no outputs.
REGISTER_OP("DeleteIterator")
    .Input("handle: resource")
    .Input("deleter: variant")
    .SetShapeFn(shape_inference::NoOutputs);

// Takes a multi-device iterator resource, `N` (>= 0) per-device iterator
// resources, and a variant `deleter`; produces no outputs.
REGISTER_OP("DeleteMultiDeviceIterator")
    .Input("multi_device_iterator: resource")
    .Input("iterators: N * resource")
    .Input("deleter: variant")
    .Attr("N: int >= 0")
    .SetShapeFn(shape_inference::NoOutputs);

// Turns `dataset` into the state of the stateful `iterator` resource (see
// the overview comment at the top of this file). It has no outputs. The
// reverse type function derives the full type of input 1 (`iterator`) from
// input 0 (`dataset`) by mapping TFT_DATASET to TFT_ITERATOR.
REGISTER_OP("MakeIterator")
    .Input("dataset: variant")
    .Input("iterator: resource")
    .SetTypeConstructor(full_type::NoOutputs())
    .SetReverseTypeFn(1, full_type::MapCovariant(TFT_DATASET, TFT_ITERATOR, 0))
    .SetShapeFn(shape_inference::NoOutputs);
911 
// Stateful op producing a scalar iterator resource `handle`. `container`
// and `shared_name` default to ''. NOTE(review): presumably the dataset is
// built by invoking the `dataset_factory` function attr — confirm.
REGISTER_OP("OneShotIterator")
    .Output("handle: resource")
    .Attr("dataset_factory: func")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Reads `components` (typed by `output_types`) from `iterator`, presumably
// its next element. Output shapes come from `output_shapes` via
// shape_inference::DatasetIteratorShape.
REGISTER_OP("IteratorGetNext")
    .Input("iterator: resource")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);

// Same signature and shape function as IteratorGetNext.
// NOTE(review): the name suggests a synchronous kernel variant — confirm.
REGISTER_OP("IteratorGetNextSync")
    .Input("iterator: resource")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);
935 
// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
// and use it to decide whether to use a stateless or stateful version of this
// op.
//
// Produces `components` (typed by `output_types`, shaped via
// DatasetIteratorShape) from `dataset`. NOTE(review): presumably fails
// unless the dataset has exactly one element — confirm against the kernel.
REGISTER_OP("DatasetToSingleElement")
    .Input("dataset: variant")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::DatasetIteratorShape);

// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
// and use it to decide whether to use a stateless or stateful version of this
// op.
//
// Reduces `input_dataset` into `components` using function `f`, starting
// from `initial_state` (types `Tstate`). `other_arguments` (types
// `Targuments`, possibly empty) are extra inputs. NOTE(review): semantics
// inferred from input/attr names — confirm that `other_arguments` are
// captured inputs of `f`.
REGISTER_OP("ReduceDataset")
    .Input("input_dataset: variant")
    .Input("initial_state: Tstate")
    .Input("other_arguments: Targuments")
    .Output("components: output_types")
    .Attr("f: func")
    .Attr("Tstate: list(type) >= 1")
    .Attr("Targuments: list(type) >= 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("use_inter_op_parallelism: bool = true")
    .Attr("metadata: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::DatasetIteratorShape);
967 
// Converts an iterator resource into a scalar `string_handle`.
REGISTER_OP("IteratorToStringHandle")
    .Input("resource_handle: resource")
    .Output("string_handle: string")
    .SetShapeFn(shape_inference::ScalarShape);

// Resolves a `string_handle` back to a scalar iterator resource. The
// `output_types` / `output_shapes` attrs are optional (default empty lists).
REGISTER_OP("IteratorFromStringHandle")
    .Input("string_handle: string")
    .Output("resource_handle: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    .SetShapeFn(shape_inference::ScalarShape);

// Registered with exactly the same signature as IteratorFromStringHandle.
REGISTER_OP("IteratorFromStringHandleV2")
    .Input("string_handle: string")
    .Output("resource_handle: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    .SetShapeFn(shape_inference::ScalarShape);

// Serializes the iterator `resource_handle` into `serialized`, a variant
// vector of unknown length. NOTE(review): the meaning of each
// `external_state_policy` value (default 0) is kernel-defined — confirm.
REGISTER_OP("SerializeIterator")
    .Input("resource_handle: resource")
    .Attr("external_state_policy: int = 0")
    .Output("serialized: variant")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Vector(c->UnknownDim()));
      return OkStatus();
    });

// Takes an iterator `resource_handle` and `serialized` state (presumably
// the output of SerializeIterator); produces no outputs.
REGISTER_OP("DeserializeIterator")
    .Input("resource_handle: resource")
    .Input("serialized: variant")
    .SetShapeFn(shape_inference::NoOutputs);
1000 
// Converts `input_dataset` into a scalar string `graph`. NOTE(review):
// presumably a serialized graph definition — confirm. Statefulness is
// controlled by `stateful_whitelist` and `allow_stateful`.
REGISTER_OP("DatasetToGraph")
    .Input("input_dataset: variant")
    .Attr("stateful_whitelist: list(string) >= 0 = []")
    .Attr("allow_stateful: bool = false")
    .Attr("strip_device_assignment: bool = false")
    .Output("graph: string")
    .SetShapeFn(shape_inference::ScalarShape);

// DatasetToGraph with the stateful-op attrs replaced by a single
// `external_state_policy` int (default 0). The forward type function marks
// output 0 as a TFT_STRING encoding of input 0's type.
REGISTER_OP("DatasetToGraphV2")
    .Input("input_dataset: variant")
    .Attr("external_state_policy: int = 0")
    .Attr("strip_device_assignment: bool = false")
    .Output("graph: string")
    .SetForwardTypeFn(full_type::Encode(TFT_STRING, 0))
    .SetShapeFn(shape_inference::ScalarShape);
1016 
// Wraps `input_dataset`, applying the optimizations named in the string
// tensor `optimizations`. `optimization_configs` defaults to an empty list.
// The output `handle` is a scalar variant.
REGISTER_OP("OptimizeDataset")
    .Input("input_dataset: variant")
    .Input("optimizations: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("optimization_configs: list(string) = []")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// OptimizeDataset with the optimization list split into three string
// inputs: enabled, disabled and default optimizations.
REGISTER_OP("OptimizeDatasetV2")
    .Input("input_dataset: variant")
    .Input("optimizations_enabled: string")
    .Input("optimizations_disabled: string")
    .Input("optimizations_default: string")
    .Output("handle: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("optimization_configs: list(string) = []")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1040 
1041 REGISTER_OP("OptionalFromValue")
1042     .Input("components: Toutput_types")
1043     .Output("optional: variant")
1044     .Attr("Toutput_types: list(type) >= 1")
1045     .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_OPTIONAL,
1046                                                            "Toutput_types"))
__anon4377504f2002(shape_inference::InferenceContext* c) 1047     .SetShapeFn([](shape_inference::InferenceContext* c) {
1048       std::vector<DataType> dtypes;
1049       TF_RETURN_IF_ERROR(c->GetAttr("Toutput_types", &dtypes));
1050       c->set_output(0, c->Scalar());
1051       std::vector<shape_inference::ShapeAndType> shapes_and_types;
1052       shapes_and_types.reserve(c->num_inputs());
1053       const FullTypeDef& ret_types = c->ret_types();
1054       for (int i = 0; i < c->num_inputs(); ++i) {
1055         // TODO(mdan): output_type(i) == optional is incorrect.
1056         // "Optional" is the type of the whole container, not of individual
1057         // elements.
1058         //
1059         // Why ret_types.args(0) and not args(i) --
1060         // For example if Toutput_types is (int32, float32), then
1061         // ret_types.args[0] (i.e. the 0th output) is
1062         // Optional[Record[Tensor[int32, s1], Tensor[float32, s2]]]
1063         // set_output_handle_shapes_and_types tracks the same thing, but in
1064         // a transposed way:
1065         // {ShapeAndType(in32, s1, Optional), ShapeAndType(in32, s2, Optional)}
1066         // That should be corrected in the future (see todo above).
1067         shapes_and_types.emplace_back(c->input(i), dtypes[i],
1068                                       ret_types.args(0));
1069       }
1070       c->set_output_handle_shapes_and_types(0, shapes_and_types);
1071       return OkStatus();
1072     });
1073 
// Produces a scalar variant `optional`. NOTE(review): presumably an empty
// optional, as the name suggests — confirm.
REGISTER_OP("OptionalNone")
    .Output("optional: variant")
    .SetShapeFn(shape_inference::ScalarShape);

// Produces a scalar bool `has_value` for the given `optional`.
REGISTER_OP("OptionalHasValue")
    .Input("optional: variant")
    .Output("has_value: bool")
    .SetShapeFn(shape_inference::ScalarShape);

// Extracts `components` (typed by `output_types`) from `optional`. Output
// shapes come from `output_shapes` via DatasetIteratorShape.
REGISTER_OP("OptionalGetValue")
    .Input("optional: variant")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);

// Reads from `iterator` into a scalar variant `optional` rather than into
// separate component tensors. The forward type function maps input 0's
// TFT_ITERATOR type to TFT_OPTIONAL.
REGISTER_OP("IteratorGetNextAsOptional")
    .Input("iterator: resource")
    .Output("optional: variant")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_OPTIONAL,
                                                           "output_types"))
    .SetForwardTypeFn(full_type::MapCovariant(TFT_ITERATOR, TFT_OPTIONAL, 0))
    .SetShapeFn(shape_inference::ScalarShape);
1099 
// Wraps `input_dataset` in a new scalar variant dataset `handle`. The
// `algorithm`, `cpu_budget` and `ram_budget` attrs default to 0.
// NOTE(review): presumably these configure pipeline performance modeling
// and autotuning; the meaning of 0 is kernel-defined — confirm.
REGISTER_OP("ModelDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("algorithm: int = 0")
    .Attr("cpu_budget: int = 0")
    .Attr("ram_budget: int = 0")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1111 
1112 // TODO(b/124308749): Add a stateful version of MapDefun and use it when `f`
1113 // is stateful.
1114 REGISTER_OP("MapDefun")
1115     .Input("arguments: Targuments")
1116     .Input("captured_inputs: Tcaptured")
1117     .Output("output: output_types")
1118     .Attr("Targuments: list(type) >= 1")
1119     .Attr("Tcaptured: list(type) >= 0 = []")
1120     .Attr("output_types: list(type) >= 1")
1121     .Attr("output_shapes: list(shape) >= 1")
1122     .Attr("f: func")
1123     .Attr("max_intra_op_parallelism: int = 1")
__anon4377504f2102(shape_inference::InferenceContext* c) 1124     .SetShapeFn([](shape_inference::InferenceContext* c) {
1125       std::vector<PartialTensorShape> output_shapes;
1126       TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
1127       DataTypeVector t_args;
1128       TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args));
1129       if (output_shapes.size() != c->num_outputs()) {
1130         return errors::InvalidArgument(
1131             "`output_shapes` must be the same length as `output_types` (",
1132             output_shapes.size(), " vs. ", c->num_outputs(), ")");
1133       }
1134 
1135       int64_t dim_zero = -1;
1136       for (size_t i = 0; i < t_args.size(); ++i) {
1137         if (c->Rank(c->input(i)) == 0) {
1138           return errors::InvalidArgument(
1139               "Arguments must have rank at least 1. Input ", i,
1140               " has rank of 0.");
1141         }
1142         auto dim_handle = c->Dim(c->input(i), 0);
1143         if (c->ValueKnown(dim_handle)) {
1144           if (dim_zero == -1) {
1145             dim_zero = c->Value(dim_handle);
1146           } else if (c->Value(dim_handle) != dim_zero) {
1147             return errors::InvalidArgument(
1148                 "Arguments must have the same dimension 0.");
1149           }
1150         }
1151       }
1152 
1153       for (size_t i = 0; i < output_shapes.size(); ++i) {
1154         PartialTensorShape s({});
1155         s = s.Concatenate(dim_zero);
1156         s = s.Concatenate(output_shapes[i]);
1157         shape_inference::ShapeHandle output_shape_handle;
1158 
1159         TF_RETURN_IF_ERROR(
1160             c->MakeShapeFromPartialTensorShape(s, &output_shape_handle));
1161         c->set_output(static_cast<int>(i), output_shape_handle);
1162       }
1163       return OkStatus();
1164     });
1165 
// Wraps a dataset variant `input_handle` into a new scalar variant
// `output_handle`. NOTE(review): the purpose of the wrapping is
// kernel-defined — confirm.
REGISTER_OP("WrapDatasetVariant")
    .Input("input_handle: variant")
    .Output("output_handle: variant")
    .SetShapeFn(shape_inference::ScalarShape);

// Counterpart of WrapDatasetVariant; outputs a scalar variant.
REGISTER_OP("UnwrapDatasetVariant")
    .Input("input_handle: variant")
    .Output("output_handle: variant")
    .SetShapeFn(shape_inference::ScalarShape);
1175 
// Creates an anonymous multi-device iterator over at least one of the
// `devices`. Both outputs, the resource `handle` and the variant `deleter`,
// are scalars.
REGISTER_OP("AnonymousMultiDeviceIterator")
    .Output("handle: resource")
    .Output("deleter: variant")
    .Attr("devices: list(string) >= 1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return OkStatus();
    });

// Like AnonymousMultiDeviceIterator but without the `deleter` output.
REGISTER_OP("AnonymousMultiDeviceIteratorV3")
    .Output("handle: resource")
    .Attr("devices: list(string) >= 1")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Scalar());
      return OkStatus();
    });
1197 
// Multi-device iterator resource with `shared_name` / `container` attrs,
// spanning at least one of the `devices`; `handle` is a scalar resource.
REGISTER_OP("MultiDeviceIterator")
    .Output("handle: resource")
    .Attr("devices: list(string) >= 1")
    .Attr("shared_name: string")
    .Attr("container: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::ScalarShape);

// Initializes `multi_device_iterator` from `dataset` with `max_buffer_size`,
// and outputs a scalar int64 `incarnation_id`. That id is an input of
// MultiDeviceIteratorGetNextFromShard.
REGISTER_OP("MultiDeviceIteratorInit")
    .Input("dataset: variant")
    .Input("multi_device_iterator: resource")
    .Input("max_buffer_size: int64")
    .Output("incarnation_id: int64")
    .SetShapeFn(shape_inference::ScalarShape);

// Reads `components` for shard `shard_num` of `multi_device_iterator`.
// `incarnation_id` presumably must match the value returned by
// MultiDeviceIteratorInit — confirm. Output shapes come from
// `output_shapes`.
REGISTER_OP("MultiDeviceIteratorGetNextFromShard")
    .Input("multi_device_iterator: resource")
    .Input("shard_num: int32")
    .Input("incarnation_id: int64")
    .Output("components: output_types")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetShapeFn(shape_inference::DatasetIteratorShape);

// Converts a multi-device iterator resource into a scalar `string_handle`;
// the forward type function encodes input 0's type as TFT_STRING.
REGISTER_OP("MultiDeviceIteratorToStringHandle")
    .Input("multi_device_iterator: resource")
    .Output("string_handle: string")
    .SetForwardTypeFn(full_type::Encode(TFT_STRING, 0))
    .SetShapeFn(shape_inference::ScalarShape);

// Inverse of MultiDeviceIteratorToStringHandle; the forward type function
// decodes input 0's TFT_STRING type. Type/shape attrs default to [].
REGISTER_OP("MultiDeviceIteratorFromStringHandle")
    .Input("string_handle: string")
    .Output("multi_device_iterator: resource")
    .Attr("output_types: list(type) >= 0 = []")
    .Attr("output_shapes: list(shape) >= 0 = []")
    .SetForwardTypeFn(full_type::Decode(TFT_STRING, 0))
    .SetShapeFn(shape_inference::ScalarShape);
1236 
// Attaches the `serialized_options` string attr to `input_dataset`; the
// output `handle` is a scalar variant dataset.
REGISTER_OP("OptionsDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("serialized_options: string")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .Attr("metadata: string = ''")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);

// Returns the options of `input_dataset` as a scalar string.
REGISTER_OP("GetOptions")
    .Input("input_dataset: variant")
    .Output("serialized_options: string")
    .SetShapeFn(shape_inference::ScalarShape);

// Wraps `input_dataset` in a scalar variant dataset `handle`.
// NOTE(review): presumably applies the dataset's options/final rewrites
// before iteration; `has_captured_ref` (default false) is kernel-defined —
// confirm.
REGISTER_OP("FinalizeDataset")
    .Input("input_dataset: variant")
    .Output("handle: variant")
    .Attr("has_captured_ref: bool = false")
    .Attr("output_types: list(type) >= 1")
    .Attr("output_shapes: list(shape) >= 1")
    .SetTypeConstructor(full_type::VariadicTensorContainer(TFT_DATASET,
                                                           "output_types"))
    .SetShapeFn(shape_inference::ScalarShape);
1262 
1263 }  // namespace tensorflow
1264