/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

namespace {

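// Shape function for ops that take only scalar inputs and produce only
// scalar outputs.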
Status ScalarInputsAndOutputs(InferenceContext* c) {
  ShapeHandle unused;
  for (int i = 0; i < c->num_inputs(); ++i) {
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
  }
  for (int i = 0; i < c->num_outputs(); ++i) {
    c->set_output(i, c->Scalar());
  }
  return OkStatus();
}

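// Shape function for ops whose inputs are 2-element vectors (reader/queue
// handles passed as Ref(string)) and whose outputs are scalars.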
Status TwoElementVectorAndScalarOutputs(InferenceContext* c) {
  ShapeHandle handle;
  DimensionHandle unused_handle;
  for (int i = 0; i < c->num_inputs(); ++i) {
    TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
    TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
  }
  for (int i = 0; i < c->num_outputs(); ++i) {
    c->set_output(i, c->Scalar());
  }
  return OkStatus();
}

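// Shape function for the legacy reader source ops, whose reader_handle
// output is a 2-element Ref(string) vector.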
Status TwoElementOutput(InferenceContext* c) {
  c->set_output(0, c->Vector(2));
  return OkStatus();
}

}  // namespace

REGISTER_OP("SaveV2")
    .Input("prefix: string")
    .Input("tensor_names: string")
    .Input("shape_and_slices: string")
    .Input("tensors: dtypes")
    .Attr("dtypes: list(type)")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      ShapeHandle s;
      DimensionHandle unused_dim;

      // Validate prefix.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));

      // Validate tensor_names and shape_and_slices.
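      // Both vectors must have one entry per saved tensor: with, say,
      // dtypes = [DT_FLOAT, DT_INT64], the op has five inputs (prefix,
      // tensor_names, shape_and_slices, t0, t1), so num_inputs() - 3 == 2.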
      for (int i = 1; i <= 2; ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &s));
        TF_RETURN_IF_ERROR(
            c->WithValue(c->Dim(s, 0), c->num_inputs() - 3, &unused_dim));
      }
      // TODO(mrry): Attempt to parse the shapes_and_slices values and use
      // them to constrain the shape of the remaining inputs.
      return OkStatus();
    });

REGISTER_OP("RestoreV2")
    .Input("prefix: string")
    .Input("tensor_names: string")
    .Input("shape_and_slices: string")
    .Output("tensors: dtypes")
    .Attr("dtypes: list(type)")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle shape0, shape1, shape2;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &shape0));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &shape1));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &shape2));
      TF_RETURN_IF_ERROR(c->Merge(shape1, shape2, &shape0));

      // Attempt to infer output shapes from the shape_and_slices input.
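      // Each non-empty element is expected to encode a full shape followed by
      // a slice spec, e.g. roughly "4 5 0,2:-" for a [4, 5] tensor restored
      // as the [2, 5] slice of rows [0, 2); see checkpoint::ParseShapeAndSlice
      // for the exact format. An empty string means the full tensor is
      // restored, with its shape left unknown here.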
      const Tensor* shape_and_slices_tensor = c->input_tensor(2);
      if (shape_and_slices_tensor) {
        if (shape_and_slices_tensor->dtype() != DT_STRING) {
          return errors::InvalidArgument(
              "Expected an input tensor of type string.");
        }

        const auto& shape_and_slices_flat =
            shape_and_slices_tensor->flat<tstring>();
        if (shape_and_slices_flat.size() != c->num_outputs()) {
          return errors::InvalidArgument(
              "The number of shape_and_slice doesn't match tensor outputs.");
        }
        for (int i = 0; i < shape_and_slices_flat.size(); ++i) {
          const string& shape_and_slice = shape_and_slices_flat(i);
          if (shape_and_slice.empty()) {
            c->set_output(i, c->UnknownShape());
            continue;
          }
          TensorShape parsed_full_shape;
          TensorSlice parsed_slice;
          TensorShape parsed_slice_shape;
          TF_RETURN_IF_ERROR(checkpoint::ParseShapeAndSlice(
              shape_and_slice, &parsed_full_shape, &parsed_slice,
              &parsed_slice_shape));
          ShapeHandle shape_handle;
          TF_RETURN_IF_ERROR(
              c->MakeShapeFromTensorShape(parsed_slice_shape, &shape_handle));
          c->set_output(i, shape_handle);
        }
        return OkStatus();
      } else {
        return UnknownShape(c);
      }
    });

REGISTER_OP("MergeV2Checkpoints")
    .Input("checkpoint_prefixes: string")
    .Input("destination_prefix: string")
    .Attr("delete_old_dirs: bool = true")
    .Attr("allow_missing_files: bool = false")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return OkStatus();
    });

REGISTER_OP("Save")
    .Input("filename: string")
    .Input("tensor_names: string")
    .Input("data: T")
    .Attr("T: list(type)")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      ShapeHandle s;
      DimensionHandle unused_dim;

      // Validate filename.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));

      // Validate tensor_names.
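      // One name per "data" tensor: the first two inputs are filename and
      // tensor_names, hence the expected length of num_inputs() - 2.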
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &s));
      TF_RETURN_IF_ERROR(
          c->WithValue(c->Dim(s, 0), c->num_inputs() - 2, &unused_dim));

      return OkStatus();
    });

REGISTER_OP("SaveSlices")
    .Input("filename: string")
    .Input("tensor_names: string")
    .Input("shapes_and_slices: string")
    .Input("data: T")
    .Attr("T: list(type)")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      ShapeHandle s;
      DimensionHandle unused_dim;

      // Validate filename.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));

      // Validate tensor_names and shapes_and_slices.
      for (int i = 1; i <= 2; ++i) {
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &s));
        TF_RETURN_IF_ERROR(
            c->WithValue(c->Dim(s, 0), c->num_inputs() - 3, &unused_dim));
      }
      // TODO(mrry): Attempt to parse the shapes_and_slices values and use
      // them to constrain the shape of the remaining inputs.
      return OkStatus();
    });

REGISTER_OP("Restore")
    .Input("file_pattern: string")
    .Input("tensor_name: string")
    .Output("tensor: dt")
    .Attr("dt: type")
    .Attr("preferred_shard: int = -1")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      c->set_output(0, c->UnknownShape());
      return OkStatus();
    });

REGISTER_OP("RestoreSlice")
    .Input("file_pattern: string")
    .Input("tensor_name: string")
    .Input("shape_and_slice: string")
    .Output("tensor: dt")
    .Attr("dt: type")
    .Attr("preferred_shard: int = -1")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));

      // Attempt to infer the output shape from the shape_and_slice input.
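      // The scalar shape_and_slice string uses the same "full shape + slice
      // spec" encoding parsed by checkpoint::ParseShapeAndSlice (see RestoreV2
      // above); an empty string leaves the output shape unknown.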
      const Tensor* shape_and_slices_tensor = c->input_tensor(2);
      if (shape_and_slices_tensor) {
        const auto& shape_and_slice =
            shape_and_slices_tensor->flat<tstring>()(0);
        if (shape_and_slice.empty()) {
          c->set_output(0, c->UnknownShape());
        } else {
          TensorShape parsed_full_shape;
          TensorSlice parsed_slice;
          TensorShape parsed_slice_shape;
          TF_RETURN_IF_ERROR(checkpoint::ParseShapeAndSlice(
              shape_and_slice, &parsed_full_shape, &parsed_slice,
              &parsed_slice_shape));
          ShapeHandle shape_handle;
          TF_RETURN_IF_ERROR(
              c->MakeShapeFromTensorShape(parsed_slice_shape, &shape_handle));
          c->set_output(0, shape_handle);
        }
      } else {
        c->set_output(0, c->UnknownShape());
      }
      return OkStatus();
    });

REGISTER_OP("ShardedFilename")
    .Input("basename: string")
    .Input("shard: int32")
    .Input("num_shards: int32")
    .Output("filename: string")
    .SetShapeFn(ScalarInputsAndOutputs);

REGISTER_OP("ShardedFilespec")
    .Input("basename: string")
    .Input("num_shards: int32")
    .Output("filename: string")
    .SetShapeFn(ScalarInputsAndOutputs);

// Reader source ops ----------------------------------------------------------

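// The V1 readers expose their state as a Ref(string) handle, represented as a
// 2-element string vector (hence TwoElementOutput); the V2 variants return a
// scalar resource handle instead.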
REGISTER_OP("WholeFileReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("WholeFileReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("TextLineReader")
    .Output("reader_handle: Ref(string)")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use TextLineReaderV2");

REGISTER_OP("TextLineReaderV2")
    .Output("reader_handle: resource")
    .Attr("skip_header_lines: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("FixedLengthRecordReader")
    .Output("reader_handle: Ref(string)")
    .Attr("header_bytes: int = 0")
    .Attr("record_bytes: int")
    .Attr("footer_bytes: int = 0")
    .Attr("hop_bytes: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use FixedLengthRecordReaderV2");

REGISTER_OP("FixedLengthRecordReaderV2")
    .Output("reader_handle: resource")
    .Attr("header_bytes: int = 0")
    .Attr("record_bytes: int")
    .Attr("footer_bytes: int = 0")
    .Attr("hop_bytes: int = 0")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("encoding: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("TFRecordReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("compression_type: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use TFRecordReaderV2");

REGISTER_OP("TFRecordReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("compression_type: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

REGISTER_OP("LMDBReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput);

REGISTER_OP("IdentityReader")
    .Output("reader_handle: Ref(string)")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(TwoElementOutput)
    .Deprecated(26, "Use IdentityReaderV2");

REGISTER_OP("IdentityReaderV2")
    .Output("reader_handle: resource")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape);

// Ops that operate on Readers ------------------------------------------------

REGISTER_OP("ReaderRead")
    .Input("reader_handle: Ref(string)")
    .Input("queue_handle: Ref(string)")
    .Output("key: string")
    .Output("value: string")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderReadV2")
    .Input("reader_handle: resource")
    .Input("queue_handle: resource")
    .Output("key: string")
    .Output("value: string")
    .SetShapeFn(ScalarInputsAndOutputs);

REGISTER_OP("ReaderReadUpTo")
    .Input("reader_handle: Ref(string)")
    .Input("queue_handle: Ref(string)")
    .Input("num_records: int64")
    .Output("keys: string")
    .Output("values: string")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
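      // The reader may return fewer than num_records records, so keys and
      // values are vectors of statically unknown length.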
      ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
      c->set_output(0, out);
      c->set_output(1, out);
      return OkStatus();
    });

REGISTER_OP("ReaderReadUpToV2")
    .Input("reader_handle: resource")
    .Input("queue_handle: resource")
    .Input("num_records: int64")
    .Output("keys: string")
    .Output("values: string")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
      c->set_output(0, out);
      c->set_output(1, out);
      return OkStatus();
    });

REGISTER_OP("ReaderNumRecordsProduced")
    .Input("reader_handle: Ref(string)")
    .Output("records_produced: int64")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderNumRecordsProducedV2")
    .Input("reader_handle: resource")
    .Output("records_produced: int64")
    .SetShapeFn(ScalarInputsAndOutputs);

REGISTER_OP("ReaderNumWorkUnitsCompleted")
    .Input("reader_handle: Ref(string)")
    .Output("units_completed: int64")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderNumWorkUnitsCompletedV2")
    .Input("reader_handle: resource")
    .Output("units_completed: int64")
    .SetShapeFn(ScalarInputsAndOutputs);

REGISTER_OP("ReaderSerializeState")
    .Input("reader_handle: Ref(string)")
    .Output("state: string")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderSerializeStateV2")
    .Input("reader_handle: resource")
    .Output("state: string")
    .SetShapeFn(ScalarInputsAndOutputs);

REGISTER_OP("ReaderRestoreState")
    .Input("reader_handle: Ref(string)")
    .Input("state: string")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
      DimensionHandle unused_handle;
      TF_RETURN_IF_ERROR(
          c->WithValue(c->Dim(c->input(0), 0), 2, &unused_handle));

      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return OkStatus();
    });

REGISTER_OP("ReaderRestoreStateV2")
    .Input("reader_handle: resource")
    .Input("state: string")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return OkStatus();
    });

REGISTER_OP("ReaderReset")
    .Input("reader_handle: Ref(string)")
    .SetShapeFn(TwoElementVectorAndScalarOutputs);

REGISTER_OP("ReaderResetV2")
    .Input("reader_handle: resource")
    .SetShapeFn(ScalarInputsAndOutputs);

// Other input Ops ----------------------------------------------------------

REGISTER_OP("ReadFile")
    .Input("filename: string")
    .Output("contents: string")
    .SetShapeFn(ScalarInputsAndOutputs);

REGISTER_OP("WriteFile")
    .Input("filename: string")
    .Input("contents: string")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      return OkStatus();
    });

REGISTER_OP("MatchingFiles")
    .Input("pattern: string")
    .Output("filenames: string")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
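      // pattern may be a scalar or a vector of patterns; the number of
      // matching filenames is not known statically.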
      c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
      return OkStatus();
    });

}  // namespace tensorflow