xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ops/io_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 #include "tensorflow/core/util/saved_tensor_slice_util.h"
20 
21 namespace tensorflow {
22 
23 using shape_inference::DimensionHandle;
24 using shape_inference::InferenceContext;
25 using shape_inference::ShapeHandle;
26 
27 namespace {
28 
ScalarInputsAndOutputs(InferenceContext * c)29 Status ScalarInputsAndOutputs(InferenceContext* c) {
30   ShapeHandle unused;
31   for (int i = 0; i < c->num_inputs(); ++i) {
32     TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
33   }
34   for (int i = 0; i < c->num_outputs(); ++i) {
35     c->set_output(i, c->Scalar());
36   }
37   return OkStatus();
38 }
39 
TwoElementVectorAndScalarOutputs(InferenceContext * c)40 Status TwoElementVectorAndScalarOutputs(InferenceContext* c) {
41   ShapeHandle handle;
42   DimensionHandle unused_handle;
43   for (int i = 0; i < c->num_inputs(); ++i) {
44     TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
45     TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
46   }
47   for (int i = 0; i < c->num_outputs(); ++i) {
48     c->set_output(i, c->Scalar());
49   }
50   return OkStatus();
51 }
52 
TwoElementOutput(InferenceContext * c)53 Status TwoElementOutput(InferenceContext* c) {
54   c->set_output(0, c->Vector(2));
55   return OkStatus();
56 }
57 
58 }  // namespace
59 
60 REGISTER_OP("SaveV2")
61     .Input("prefix: string")
62     .Input("tensor_names: string")
63     .Input("shape_and_slices: string")
64     .Input("tensors: dtypes")
65     .Attr("dtypes: list(type)")
66     .SetIsStateful()
__anon2a9f7d410202(InferenceContext* c) 67     .SetShapeFn([](InferenceContext* c) {
68       ShapeHandle unused;
69       ShapeHandle s;
70       DimensionHandle unused_dim;
71 
72       // Validate prefix.
73       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
74 
75       // Validate tensor_names and shapes_and_slices.
76       for (int i = 1; i <= 2; ++i) {
77         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &s));
78         TF_RETURN_IF_ERROR(
79             c->WithValue(c->Dim(s, 0), c->num_inputs() - 3, &unused_dim));
80       }
81       // TODO(mrry): Attempt to parse the shapes_and_slices values and use
82       // them to constrain the shape of the remaining inputs.
83       return OkStatus();
84     });
85 
86 REGISTER_OP("RestoreV2")
87     .Input("prefix: string")
88     .Input("tensor_names: string")
89     .Input("shape_and_slices: string")
90     .Output("tensors: dtypes")
91     .Attr("dtypes: list(type)")
92     .SetIsStateful()
__anon2a9f7d410302(InferenceContext* c) 93     .SetShapeFn([](InferenceContext* c) {
94       ShapeHandle shape0, shape1, shape2;
95       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &shape0));
96       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &shape1));
97       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &shape2));
98       TF_RETURN_IF_ERROR(c->Merge(shape1, shape2, &shape0));
99 
100       // Attempt to infer output shapes from its shape_and_slice input.
101       const Tensor* shape_and_slices_tensor = c->input_tensor(2);
102       if (shape_and_slices_tensor) {
103         if (shape_and_slices_tensor->dtype() != DT_STRING) {
104           return errors::InvalidArgument(
105               "Expected an input tensor of type string.");
106         }
107 
108         const auto& shape_and_slices_flat =
109             shape_and_slices_tensor->flat<tstring>();
110         if (shape_and_slices_flat.size() != c->num_outputs()) {
111           return errors::InvalidArgument(
112               "The number of shape_and_slice doesn't match tensor outputs.");
113         }
114         for (int i = 0; i < shape_and_slices_flat.size(); ++i) {
115           const string& shape_and_slice = shape_and_slices_flat(i);
116           if (shape_and_slice.empty()) {
117             c->set_output(i, c->UnknownShape());
118             continue;
119           }
120           TensorShape parsed_full_shape;
121           TensorSlice parsed_slice;
122           TensorShape parsed_slice_shape;
123           TF_RETURN_IF_ERROR(checkpoint::ParseShapeAndSlice(
124               shape_and_slice, &parsed_full_shape, &parsed_slice,
125               &parsed_slice_shape));
126           ShapeHandle shape_handle;
127           TF_RETURN_IF_ERROR(
128               c->MakeShapeFromTensorShape(parsed_slice_shape, &shape_handle));
129           c->set_output(i, shape_handle);
130         }
131         return OkStatus();
132       } else {
133         return UnknownShape(c);
134       }
135     });
136 
137 REGISTER_OP("MergeV2Checkpoints")
138     .Input("checkpoint_prefixes: string")
139     .Input("destination_prefix: string")
140     .Attr("delete_old_dirs: bool = true")
141     .Attr("allow_missing_files: bool = false")
142     .SetIsStateful()
__anon2a9f7d410402(InferenceContext* c) 143     .SetShapeFn([](InferenceContext* c) {
144       ShapeHandle unused;
145       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
146       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
147       return OkStatus();
148     });
149 
150 REGISTER_OP("Save")
151     .Input("filename: string")
152     .Input("tensor_names: string")
153     .Input("data: T")
154     .Attr("T: list(type)")
155     .SetIsStateful()
__anon2a9f7d410502(InferenceContext* c) 156     .SetShapeFn([](InferenceContext* c) {
157       ShapeHandle unused;
158       ShapeHandle s;
159       DimensionHandle unused_dim;
160 
161       // Validate filename.
162       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
163 
164       // Validate tensor_names.
165       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &s));
166       TF_RETURN_IF_ERROR(
167           c->WithValue(c->Dim(s, 0), c->num_inputs() - 2, &unused_dim));
168 
169       return OkStatus();
170     });
171 
172 REGISTER_OP("SaveSlices")
173     .Input("filename: string")
174     .Input("tensor_names: string")
175     .Input("shapes_and_slices: string")
176     .Input("data: T")
177     .Attr("T: list(type)")
178     .SetIsStateful()
__anon2a9f7d410602(InferenceContext* c) 179     .SetShapeFn([](InferenceContext* c) {
180       ShapeHandle unused;
181       ShapeHandle s;
182       DimensionHandle unused_dim;
183 
184       // Validate filename.
185       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
186 
187       // Validate tensor_names and unused_shapes_and_slices.
188       for (int i = 1; i <= 2; ++i) {
189         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &s));
190         TF_RETURN_IF_ERROR(
191             c->WithValue(c->Dim(s, 0), c->num_inputs() - 3, &unused_dim));
192       }
193       // TODO(mrry): Attempt to parse the shapes_and_slices values and use
194       // them to constrain the shape of the remaining inputs.
195       return OkStatus();
196     });
197 
198 REGISTER_OP("Restore")
199     .Input("file_pattern: string")
200     .Input("tensor_name: string")
201     .Output("tensor: dt")
202     .Attr("dt: type")
203     .Attr("preferred_shard: int = -1")
204     .SetIsStateful()
__anon2a9f7d410702(InferenceContext* c) 205     .SetShapeFn([](InferenceContext* c) {
206       ShapeHandle unused;
207       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
208       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
209       c->set_output(0, c->UnknownShape());
210       return OkStatus();
211     });
212 
213 REGISTER_OP("RestoreSlice")
214     .Input("file_pattern: string")
215     .Input("tensor_name: string")
216     .Input("shape_and_slice: string")
217     .Output("tensor: dt")
218     .Attr("dt: type")
219     .Attr("preferred_shard: int = -1")
220     .SetIsStateful()
__anon2a9f7d410802(InferenceContext* c) 221     .SetShapeFn([](InferenceContext* c) {
222       ShapeHandle unused;
223       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
224       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
225       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
226 
227       // Attempt to infer output shapes from its shape_and_slice input.
228       const Tensor* shape_and_slices_tensor = c->input_tensor(2);
229       if (shape_and_slices_tensor) {
230         const auto& shape_and_slice =
231             shape_and_slices_tensor->flat<tstring>()(0);
232         if (shape_and_slice.empty()) {
233           c->set_output(0, c->UnknownShape());
234         } else {
235           TensorShape parsed_full_shape;
236           TensorSlice parsed_slice;
237           TensorShape parsed_slice_shape;
238           TF_RETURN_IF_ERROR(checkpoint::ParseShapeAndSlice(
239               shape_and_slice, &parsed_full_shape, &parsed_slice,
240               &parsed_slice_shape));
241           ShapeHandle shape_handle;
242           TF_RETURN_IF_ERROR(
243               c->MakeShapeFromTensorShape(parsed_slice_shape, &shape_handle));
244           c->set_output(0, shape_handle);
245         }
246       } else {
247         c->set_output(0, c->UnknownShape());
248       }
249       return OkStatus();
250     });
251 
// Builds a filename from `basename`, `shard` and `num_shards`. Shape fn
// ScalarInputsAndOutputs requires all three inputs to be scalars and makes
// the single `filename` output a scalar.
REGISTER_OP("ShardedFilename")
    .Input("basename: string")
    .Input("shard: int32")
    .Input("num_shards: int32")
    .Output("filename: string")
    .SetShapeFn(ScalarInputsAndOutputs);

// Builds a filespec from `basename` and `num_shards`. Same shape contract:
// scalar inputs and a scalar `filename` output.
REGISTER_OP("ShardedFilespec")
    .Input("basename: string")
    .Input("num_shards: int32")
    .Output("filename: string")
    .SetShapeFn(ScalarInputsAndOutputs);
264 
// Reader source ops ----------------------------------------------------------
//
// Each op below creates a reader and emits its handle.
// - Legacy ops emit `reader_handle` as a Ref(string). Their shape fn
//   TwoElementOutput makes it a length-2 vector.
// - "V2" ops emit it as a `resource` with a scalar shape.
// TextLineReader, FixedLengthRecordReader, TFRecordReader and IdentityReader
// are marked `.Deprecated(26, ...)` in favour of their V2 forms.
// WholeFileReader and LMDBReader carry no deprecation mark here.
// All of these ops are stateful, and `container` / `shared_name` default to
// the empty string.

REGISTER_OP("WholeFileReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("WholeFileReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// `skip_header_lines` defaults to 0.
REGISTER_OP("TextLineReader")
    .Output("reader_handle: Ref(string)")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use TextLineReaderV2");

REGISTER_OP("TextLineReaderV2")
    .Output("reader_handle: resource")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// `record_bytes` has no default and must be supplied. The header, footer and
// hop byte counts default to 0.
REGISTER_OP("FixedLengthRecordReader")
    .Output("reader_handle: Ref(string)")
    .Attr("header_bytes: int = 0")
    .Attr("record_bytes: int")
    .Attr("footer_bytes: int = 0")
    .Attr("hop_bytes: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use FixedLengthRecordReaderV2");

// Same attrs as the legacy op, plus an `encoding` attr (default '').
REGISTER_OP("FixedLengthRecordReaderV2")
    .Output("reader_handle: resource")
    .Attr("header_bytes: int = 0")
    .Attr("record_bytes: int")
    .Attr("footer_bytes: int = 0")
    .Attr("hop_bytes: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("encoding: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// `compression_type` defaults to ''.
REGISTER_OP("TFRecordReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("compression_type: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use TFRecordReaderV2");

REGISTER_OP("TFRecordReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("compression_type: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Only a Ref(string) form is registered in this file.
REGISTER_OP("LMDBReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("IdentityReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use IdentityReaderV2");

REGISTER_OP("IdentityReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);
360 
// Ops that operate on Readers ------------------------------------------------
//
// The legacy ops take Ref(string) handles and validate them as length-2
// vectors via TwoElementVectorAndScalarOutputs. The V2 ops take `resource`
// handles and require scalar inputs via ScalarInputsAndOutputs. Under both
// shape fns every output is a scalar.

// Produces one (key, value) pair; both outputs are scalars.
REGISTER_OP("ReaderRead")
    .Input("reader_handle: Ref(string)")
    .Input("queue_handle: Ref(string)")
    .Output("key: string")
    .Output("value: string")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderReadV2")
    .Input("reader_handle: resource")
    .Input("queue_handle: resource")
    .Output("key: string")
    .Output("value: string")
    .SetShapeFn(ScalarInputsAndOutputs);
376 
377 REGISTER_OP("ReaderReadUpTo")
378     .Input("reader_handle: Ref(string)")
379     .Input("queue_handle: Ref(string)")
380     .Input("num_records: int64")
381     .Output("keys: string")
382     .Output("values: string")
__anon2a9f7d410902(InferenceContext* c) 383     .SetShapeFn([](InferenceContext* c) {
384       ShapeHandle unused;
385       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
386       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
387       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
388       ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
389       c->set_output(0, out);
390       c->set_output(1, out);
391       return OkStatus();
392     });
393 
394 REGISTER_OP("ReaderReadUpToV2")
395     .Input("reader_handle: resource")
396     .Input("queue_handle: resource")
397     .Input("num_records: int64")
398     .Output("keys: string")
399     .Output("values: string")
__anon2a9f7d410a02(InferenceContext* c) 400     .SetShapeFn([](InferenceContext* c) {
401       ShapeHandle unused;
402       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
403       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
404       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
405       ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
406       c->set_output(0, out);
407       c->set_output(1, out);
408       return OkStatus();
409     });
410 
// Reader query ops. Each has a single scalar output.
// - Legacy forms take a Ref(string) handle, validated as a length-2 vector
//   by TwoElementVectorAndScalarOutputs.
// - V2 forms take a scalar `resource` handle, validated by
//   ScalarInputsAndOutputs.

REGISTER_OP("ReaderNumRecordsProduced")
    .Input("reader_handle: Ref(string)")
    .Output("records_produced: int64")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderNumRecordsProducedV2")
    .Input("reader_handle: resource")
    .Output("records_produced: int64")
    .SetShapeFn(ScalarInputsAndOutputs);

REGISTER_OP("ReaderNumWorkUnitsCompleted")
    .Input("reader_handle: Ref(string)")
    .Output("units_completed: int64")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderNumWorkUnitsCompletedV2")
    .Input("reader_handle: resource")
    .Output("units_completed: int64")
    .SetShapeFn(ScalarInputsAndOutputs);

REGISTER_OP("ReaderSerializeState")
    .Input("reader_handle: Ref(string)")
    .Output("state: string")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderSerializeStateV2")
    .Input("reader_handle: resource")
    .Output("state: string")
    .SetShapeFn(ScalarInputsAndOutputs);
440 
441 REGISTER_OP("ReaderRestoreState")
442     .Input("reader_handle: Ref(string)")
443     .Input("state: string")
__anon2a9f7d410b02(InferenceContext* c) 444     .SetShapeFn([](InferenceContext* c) {
445       ShapeHandle unused;
446       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
447       DimensionHandle unused_handle;
448       TF_RETURN_IF_ERROR(
449           c->WithValue(c->Dim(c->input(0), 0), 2, &unused_handle));
450 
451       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
452       return OkStatus();
453     });
454 
455 REGISTER_OP("ReaderRestoreStateV2")
456     .Input("reader_handle: resource")
457     .Input("state: string")
__anon2a9f7d410c02(InferenceContext* c) 458     .SetShapeFn([](InferenceContext* c) {
459       ShapeHandle unused;
460       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
461       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
462       return OkStatus();
463     });
464 
// Neither op declares outputs; the shape fns only validate the handle.
// The legacy op requires a length-2 vector Ref(string) handle; the V2 op
// requires a scalar resource handle.
REGISTER_OP("ReaderReset")
    .Input("reader_handle: Ref(string)")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderResetV2")
    .Input("reader_handle: resource")
    .SetShapeFn(ScalarInputsAndOutputs);
472 
473 // Other input Ops ----------------------------------------------------------
474 
475 REGISTER_OP("ReadFile")
476     .Input("filename: string")
477     .Output("contents: string")
478     .SetShapeFn(ScalarInputsAndOutputs);
479 
480 REGISTER_OP("WriteFile")
481     .Input("filename: string")
482     .Input("contents: string")
483     .SetIsStateful()
__anon2a9f7d410d02(InferenceContext* c) 484     .SetShapeFn([](InferenceContext* c) {
485       ShapeHandle unused;
486       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
487       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
488       return OkStatus();
489     });
490 
491 REGISTER_OP("MatchingFiles")
492     .Input("pattern: string")
493     .Output("filenames: string")
__anon2a9f7d410e02(InferenceContext* c) 494     .SetShapeFn([](InferenceContext* c) {
495       ShapeHandle unused;
496       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
497       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
498       return OkStatus();
499     });
500 
501 }  // namespace tensorflow
502