# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for reader Datasets."""
import os

from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import structured_function
from tensorflow.python.data.util import convert
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

_DEFAULT_READER_BUFFER_SIZE_BYTES = 256 * 1024  # 256 KB


def _normalise_fspath(path):
  """Convert pathlib-like objects to str (__fspath__ compatibility, PEP 519)."""
  return os.fspath(path) if isinstance(path, os.PathLike) else path


def _create_or_validate_filenames_dataset(filenames, name=None):
  """Creates (or validates) a dataset of filenames.

  Args:
    filenames: Either a list or dataset of filenames. If it is a list, it is
      converted to a dataset. If it is a dataset, its type and shape are
      validated.
    name: (Optional.) A name for the tf.data operation.

  Returns:
    A dataset of filenames.
  """
  if isinstance(filenames, dataset_ops.DatasetV2):
    element_type = dataset_ops.get_legacy_output_types(filenames)
    if element_type != dtypes.string:
      raise TypeError(
          "The `filenames` argument must contain `tf.string` elements. Got a "
          f"dataset of `{element_type!r}` elements.")
    element_shape = dataset_ops.get_legacy_output_shapes(filenames)
    if not element_shape.is_compatible_with(tensor_shape.TensorShape([])):
      raise TypeError(
          "The `filenames` argument must contain `tf.string` elements of shape "
          "[] (i.e. scalars). Got a dataset of element shape "
          f"{element_shape!r}.")
  else:
    filenames = nest.map_structure(_normalise_fspath, filenames)
    filenames = ops.convert_to_tensor(filenames, dtype_hint=dtypes.string)
    if filenames.dtype != dtypes.string:
      raise TypeError(
          "The `filenames` argument must contain `tf.string` elements. Got "
          f"`{filenames.dtype!r}` elements.")
    filenames = array_ops.reshape(filenames, [-1], name="flat_filenames")
    filenames = dataset_ops.TensorSliceDataset(
        filenames, is_files=True, name=name)
  return filenames

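# A minimal usage sketch for the helper above (illustrative only; the file
# paths are hypothetical):
#
#   # A list of Python strings (or `os.PathLike` objects) is flattened and
#   # converted to a dataset of scalar `tf.string` filenames.
#   filenames_ds = _create_or_validate_filenames_dataset(
#       ["/tmp/a.txt", "/tmp/b.txt"])
#
#   # An existing dataset is only validated: it must yield scalar `tf.string`
#   # elements, otherwise a `TypeError` is raised.
#   filenames_ds = _create_or_validate_filenames_dataset(
#       dataset_ops.DatasetV2.from_tensor_slices(["/tmp/a.txt", "/tmp/b.txt"]))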

def _create_dataset_reader(dataset_creator,
                           filenames,
                           num_parallel_reads=None,
                           name=None):
  """Creates a dataset that reads the given files using the given reader.

  Args:
    dataset_creator: A function that takes in a single file name and returns a
      dataset.
    filenames: A `tf.data.Dataset` containing one or more filenames.
    num_parallel_reads: (Optional.) The number of files to read in parallel.
      If `None`, files are read sequentially.
    name: (Optional.) A name for the tf.data operation.

  Returns:
    A `Dataset` that reads data from `filenames`.
  """

  def read_one_file(filename):
    filename = ops.convert_to_tensor(filename, dtypes.string, name="filename")
    return dataset_creator(filename)

  if num_parallel_reads is None:
    return filenames.flat_map(read_one_file, name=name)
  elif num_parallel_reads == dataset_ops.AUTOTUNE:
    return filenames.interleave(
        read_one_file, num_parallel_calls=num_parallel_reads, name=name)
  else:
    return ParallelInterleaveDataset(
        filenames,
        read_one_file,
        cycle_length=num_parallel_reads,
        block_length=1,
        sloppy=False,
        buffer_output_elements=None,
        prefetch_input_elements=None,
        name=name)

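# Illustrative sketch of how `num_parallel_reads` selects the read strategy in
# the helper above (the file names are hypothetical; `_TextLineDataset` is
# defined later in this module):
#
#   files = _create_or_validate_filenames_dataset(["/tmp/a.txt", "/tmp/b.txt"])
#   reader = lambda f: _TextLineDataset(f)
#
#   _create_dataset_reader(reader, files)                 # flat_map, sequential
#   _create_dataset_reader(reader, files,                 # autotuned interleave
#                          num_parallel_reads=dataset_ops.AUTOTUNE)
#   _create_dataset_reader(reader, files,                 # ParallelInterleaveDataset
#                          num_parallel_reads=4)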

def _get_type(value):
  """Returns the type of `value` if it is a TypeSpec."""

  if isinstance(value, type_spec.TypeSpec):
    return value.value_type
  else:
    return type(value)


class _TextLineDataset(dataset_ops.DatasetSource):
  """A `Dataset` comprising records from one or more text files."""

  def __init__(self,
               filenames,
               compression_type=None,
               buffer_size=None,
               name=None):
    """Creates a `TextLineDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer. A value of 0 results in the default buffering values chosen
        based on the compression type.
      name: (Optional.) A name for the tf.data operation.
    """
    self._filenames = filenames
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size",
        buffer_size,
        argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)
    self._name = name

    variant_tensor = gen_dataset_ops.text_line_dataset(
        self._filenames,
        self._compression_type,
        self._buffer_size,
        metadata=self._metadata.SerializeToString())
    super(_TextLineDataset, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return tensor_spec.TensorSpec([], dtypes.string)


@tf_export("data.TextLineDataset", v1=[])
class TextLineDatasetV2(dataset_ops.DatasetSource):
  r"""Creates a `Dataset` comprising lines from one or more text files.

  The `tf.data.TextLineDataset` loads text from text files and creates a dataset
  where each line of the files becomes an element of the dataset.

  For example, suppose we have 2 files "text_lines0.txt" and "text_lines1.txt"
  with the following lines:

  >>> with open('/tmp/text_lines0.txt', 'w') as f:
  ...   f.write('the cow\n')
  ...   f.write('jumped over\n')
  ...   f.write('the moon\n')
  >>> with open('/tmp/text_lines1.txt', 'w') as f:
  ...   f.write('jack and jill\n')
  ...   f.write('went up\n')
  ...   f.write('the hill\n')

  We can construct a TextLineDataset from them as follows:

  >>> dataset = tf.data.TextLineDataset(['/tmp/text_lines0.txt',
  ...                                    '/tmp/text_lines1.txt'])

  The elements of the dataset are expected to be:

  >>> for element in dataset.as_numpy_iterator():
  ...   print(element)
  b'the cow'
  b'jumped over'
  b'the moon'
  b'jack and jill'
  b'went up'
  b'the hill'
  """

  def __init__(self,
               filenames,
               compression_type=None,
               buffer_size=None,
               num_parallel_reads=None,
               name=None):
    r"""Creates a `TextLineDataset`.

    The elements of the dataset will be the lines of the input files, using
    the newline character '\n' to denote line splits. The newline characters
    will be stripped off of each element.

    Args:
      filenames: A `tf.data.Dataset` whose elements are `tf.string` scalars, a
        `tf.string` tensor, or a value that can be converted to a `tf.string`
        tensor (such as a list of Python strings).
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
        to buffer. A value of 0 results in the default buffering values chosen
        based on the compression type.
      num_parallel_reads: (Optional.) A `tf.int64` scalar representing the
        number of files to read in parallel. If greater than one, the records of
        files read in parallel are outputted in an interleaved order. If your
        input pipeline is I/O bottlenecked, consider setting this parameter to a
        value greater than one to parallelize the I/O. If `None`, files will be
        read sequentially.
      name: (Optional.) A name for the tf.data operation.
    """
    filenames = _create_or_validate_filenames_dataset(filenames, name=name)
    self._filenames = filenames
    self._compression_type = compression_type
    self._buffer_size = buffer_size

    def creator_fn(filename):
      return _TextLineDataset(
          filename, compression_type, buffer_size, name=name)

    self._impl = _create_dataset_reader(
        creator_fn, filenames, num_parallel_reads, name=name)
    variant_tensor = self._impl._variant_tensor  # pylint: disable=protected-access

    super(TextLineDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return tensor_spec.TensorSpec([], dtypes.string)


@tf_export(v1=["data.TextLineDataset"])
class TextLineDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` comprising lines from one or more text files."""

  def __init__(self,
               filenames,
               compression_type=None,
               buffer_size=None,
               num_parallel_reads=None,
               name=None):
    wrapped = TextLineDatasetV2(filenames, compression_type, buffer_size,
                                num_parallel_reads, name)
    super(TextLineDatasetV1, self).__init__(wrapped)

  __init__.__doc__ = TextLineDatasetV2.__init__.__doc__

  @property
  def _filenames(self):
    return self._dataset._filenames  # pylint: disable=protected-access

  @_filenames.setter
  def _filenames(self, value):
    self._dataset._filenames = value  # pylint: disable=protected-access


class _TFRecordDataset(dataset_ops.DatasetSource):
  """A `Dataset` comprising records from one or more TFRecord files."""

  def __init__(self,
               filenames,
               compression_type=None,
               buffer_size=None,
               name=None):
    """Creates a `TFRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. 0 means no buffering.
      name: (Optional.) A name for the tf.data operation.
    """
    self._filenames = filenames
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size",
        buffer_size,
        argument_default=_DEFAULT_READER_BUFFER_SIZE_BYTES)
    self._name = name

    variant_tensor = gen_dataset_ops.tf_record_dataset(
        self._filenames, self._compression_type, self._buffer_size,
        metadata=self._metadata.SerializeToString())
    super(_TFRecordDataset, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return tensor_spec.TensorSpec([], dtypes.string)


class ParallelInterleaveDataset(dataset_ops.UnaryDataset):
  """A `Dataset` that maps a function over its input and flattens the result."""

  def __init__(self,
               input_dataset,
               map_func,
               cycle_length,
               block_length,
               sloppy,
               buffer_output_elements,
               prefetch_input_elements,
               name=None):
    """See `tf.data.experimental.parallel_interleave()` for details."""
    self._input_dataset = input_dataset
    self._map_func = structured_function.StructuredFunctionWrapper(
        map_func, self._transformation_name(), dataset=input_dataset)
    if not isinstance(self._map_func.output_structure, dataset_ops.DatasetSpec):
      raise TypeError(
          "The `map_func` argument must return a `Dataset` object. Got "
          f"{_get_type(self._map_func.output_structure)!r}.")
    self._element_spec = self._map_func.output_structure._element_spec  # pylint: disable=protected-access
    self._cycle_length = ops.convert_to_tensor(
        cycle_length, dtype=dtypes.int64, name="cycle_length")
    self._block_length = ops.convert_to_tensor(
        block_length, dtype=dtypes.int64, name="block_length")
    self._buffer_output_elements = convert.optional_param_to_tensor(
        "buffer_output_elements",
        buffer_output_elements,
        argument_default=2 * block_length)
    self._prefetch_input_elements = convert.optional_param_to_tensor(
        "prefetch_input_elements",
        prefetch_input_elements,
        argument_default=2 * cycle_length)
    if sloppy is None:
      self._deterministic = "default"
    elif sloppy:
      self._deterministic = "false"
    else:
      self._deterministic = "true"
    self._name = name

    variant_tensor = ged_ops.legacy_parallel_interleave_dataset_v2(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        self._cycle_length,
        self._block_length,
        self._buffer_output_elements,
        self._prefetch_input_elements,
        f=self._map_func.function,
        deterministic=self._deterministic,
        **self._common_args)
    super(ParallelInterleaveDataset, self).__init__(input_dataset,
                                                    variant_tensor)

  def _functions(self):
    return [self._map_func]

  @property
  def element_spec(self):
    return self._element_spec

  def _transformation_name(self):
    return "tf.data.experimental.parallel_interleave()"

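# Illustrative sketch only: `_create_dataset_reader` above constructs this
# internal dataset when `num_parallel_reads` is a fixed integer (the filename
# below is hypothetical):
#
#   files = _create_or_validate_filenames_dataset(["/tmp/a.tfrecord"])
#   ds = ParallelInterleaveDataset(
#       files,
#       lambda f: _TFRecordDataset(f),
#       cycle_length=4,
#       block_length=1,
#       sloppy=False,
#       buffer_output_elements=None,
#       prefetch_input_elements=None)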

@tf_export("data.TFRecordDataset", v1=[])
class TFRecordDatasetV2(dataset_ops.DatasetV2):
  """A `Dataset` comprising records from one or more TFRecord files.

  This dataset loads TFRecords from the files as bytes, exactly as they were
  written. `TFRecordDataset` does not do any parsing or decoding on its own.
  Parsing and decoding can be done by applying `Dataset.map` transformations
  after the `TFRecordDataset`.

  A minimal example is given below:

  >>> import tempfile
  >>> example_path = os.path.join(tempfile.gettempdir(), "example.tfrecords")
  >>> np.random.seed(0)

  >>> # Write the records to a file.
  ... with tf.io.TFRecordWriter(example_path) as file_writer:
  ...   for _ in range(4):
  ...     x, y = np.random.random(), np.random.random()
  ...
  ...     record_bytes = tf.train.Example(features=tf.train.Features(feature={
  ...         "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
  ...         "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
  ...     })).SerializeToString()
  ...     file_writer.write(record_bytes)

  >>> # Read the data back out.
  >>> def decode_fn(record_bytes):
  ...   return tf.io.parse_single_example(
  ...       # Data
  ...       record_bytes,
  ...
  ...       # Schema
  ...       {"x": tf.io.FixedLenFeature([], dtype=tf.float32),
  ...        "y": tf.io.FixedLenFeature([], dtype=tf.float32)}
  ...   )

  >>> for batch in tf.data.TFRecordDataset([example_path]).map(decode_fn):
  ...   print("x = {x:.4f},  y = {y:.4f}".format(**batch))
  x = 0.5488,  y = 0.7152
  x = 0.6028,  y = 0.5449
  x = 0.4237,  y = 0.6459
  x = 0.4376,  y = 0.8918
  """

  def __init__(self,
               filenames,
               compression_type=None,
               buffer_size=None,
               num_parallel_reads=None,
               name=None):
    """Creates a `TFRecordDataset` to read one or more TFRecord files.

    Each element of the dataset will contain a single TFRecord.

    Args:
      filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or
        more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. If your input pipeline is I/O bottlenecked,
        consider setting this parameter to a value of 1-100 MBs. If `None`, a
        sensible default for both local and remote file systems is used.
      num_parallel_reads: (Optional.) A `tf.int64` scalar representing the
        number of files to read in parallel. If greater than one, the records of
        files read in parallel are outputted in an interleaved order. If your
        input pipeline is I/O bottlenecked, consider setting this parameter to a
        value greater than one to parallelize the I/O. If `None`, files will be
        read sequentially.
      name: (Optional.) A name for the tf.data operation.

    Raises:
      TypeError: If any argument does not have the expected type.
      ValueError: If any argument does not have the expected shape.
    """
    filenames = _create_or_validate_filenames_dataset(filenames, name=name)

    self._filenames = filenames
    self._compression_type = compression_type
    self._buffer_size = buffer_size
    self._num_parallel_reads = num_parallel_reads

    def creator_fn(filename):
      return _TFRecordDataset(
          filename, compression_type, buffer_size, name=name)

    self._impl = _create_dataset_reader(
        creator_fn, filenames, num_parallel_reads, name=name)
    variant_tensor = self._impl._variant_tensor  # pylint: disable=protected-access
    super(TFRecordDatasetV2, self).__init__(variant_tensor)

  def _inputs(self):
    return self._impl._inputs()  # pylint: disable=protected-access

  @property
  def element_spec(self):
    return tensor_spec.TensorSpec([], dtypes.string)


@tf_export(v1=["data.TFRecordDataset"])
class TFRecordDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` comprising records from one or more TFRecord files."""

  def __init__(self,
               filenames,
               compression_type=None,
               buffer_size=None,
               num_parallel_reads=None,
               name=None):
    wrapped = TFRecordDatasetV2(
        filenames, compression_type, buffer_size, num_parallel_reads, name=name)
    super(TFRecordDatasetV1, self).__init__(wrapped)

  __init__.__doc__ = TFRecordDatasetV2.__init__.__doc__

  @property
  def _filenames(self):
    return self._dataset._filenames  # pylint: disable=protected-access

  @_filenames.setter
  def _filenames(self, value):
    self._dataset._filenames = value  # pylint: disable=protected-access


class _FixedLengthRecordDataset(dataset_ops.DatasetSource):
  """A `Dataset` of fixed-length records from one or more binary files."""

  def __init__(self,
               filenames,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               buffer_size=None,
               compression_type=None,
               name=None):
    """Creates a `FixedLengthRecordDataset`.

    Args:
      filenames: A `tf.string` tensor containing one or more filenames.
      record_bytes: A `tf.int64` scalar representing the number of bytes in each
        record.
      header_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to skip at the start of a file.
      footer_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to ignore at the end of a file.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes to buffer when reading.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      name: (Optional.) A name for the tf.data operation.
    """
    self._filenames = filenames
    self._record_bytes = ops.convert_to_tensor(
        record_bytes, dtype=dtypes.int64, name="record_bytes")
    self._header_bytes = convert.optional_param_to_tensor(
        "header_bytes", header_bytes)
    self._footer_bytes = convert.optional_param_to_tensor(
        "footer_bytes", footer_bytes)
    self._buffer_size = convert.optional_param_to_tensor(
        "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
    self._compression_type = convert.optional_param_to_tensor(
        "compression_type",
        compression_type,
        argument_default="",
        argument_dtype=dtypes.string)
    self._name = name

    variant_tensor = gen_dataset_ops.fixed_length_record_dataset_v2(
        self._filenames,
        self._header_bytes,
        self._record_bytes,
        self._footer_bytes,
        self._buffer_size,
        self._compression_type,
        metadata=self._metadata.SerializeToString())
    super(_FixedLengthRecordDataset, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return tensor_spec.TensorSpec([], dtypes.string)


@tf_export("data.FixedLengthRecordDataset", v1=[])
class FixedLengthRecordDatasetV2(dataset_ops.DatasetSource):
  """A `Dataset` of fixed-length records from one or more binary files.

  The `tf.data.FixedLengthRecordDataset` reads fixed length records from binary
  files and creates a dataset where each record becomes an element of the
  dataset. The binary files can have a fixed length header and a fixed length
  footer, which will both be skipped.

  For example, suppose we have 2 files "fixed_length0.bin" and
  "fixed_length1.bin" with the following content:

  >>> with open('/tmp/fixed_length0.bin', 'wb') as f:
  ...   f.write(b'HEADER012345FOOTER')
  >>> with open('/tmp/fixed_length1.bin', 'wb') as f:
  ...   f.write(b'HEADER6789abFOOTER')

  We can construct a `FixedLengthRecordDataset` from them as follows:

  >>> dataset1 = tf.data.FixedLengthRecordDataset(
  ...     filenames=['/tmp/fixed_length0.bin', '/tmp/fixed_length1.bin'],
  ...     record_bytes=2, header_bytes=6, footer_bytes=6)

  The elements of the dataset are:

  >>> for element in dataset1.as_numpy_iterator():
  ...   print(element)
  b'01'
  b'23'
  b'45'
  b'67'
  b'89'
  b'ab'
  """

  def __init__(self,
               filenames,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               buffer_size=None,
               compression_type=None,
               num_parallel_reads=None,
               name=None):
    """Creates a `FixedLengthRecordDataset`.

    Args:
      filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or
        more filenames.
      record_bytes: A `tf.int64` scalar representing the number of bytes in each
        record.
      header_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to skip at the start of a file.
      footer_bytes: (Optional.) A `tf.int64` scalar representing the number of
        bytes to ignore at the end of a file.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes to buffer when reading.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      num_parallel_reads: (Optional.) A `tf.int64` scalar representing the
        number of files to read in parallel. If greater than one, the records of
        files read in parallel are outputted in an interleaved order. If your
        input pipeline is I/O bottlenecked, consider setting this parameter to a
        value greater than one to parallelize the I/O. If `None`, files will be
        read sequentially.
      name: (Optional.) A name for the tf.data operation.
    """
    filenames = _create_or_validate_filenames_dataset(filenames, name=name)

    self._filenames = filenames
    self._record_bytes = record_bytes
    self._header_bytes = header_bytes
    self._footer_bytes = footer_bytes
    self._buffer_size = buffer_size
    self._compression_type = compression_type

    def creator_fn(filename):
      return _FixedLengthRecordDataset(
          filename,
          record_bytes,
          header_bytes,
          footer_bytes,
          buffer_size,
          compression_type,
          name=name)

    self._impl = _create_dataset_reader(
        creator_fn, filenames, num_parallel_reads, name=name)
    variant_tensor = self._impl._variant_tensor  # pylint: disable=protected-access
    super(FixedLengthRecordDatasetV2, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return tensor_spec.TensorSpec([], dtypes.string)


@tf_export(v1=["data.FixedLengthRecordDataset"])
class FixedLengthRecordDatasetV1(dataset_ops.DatasetV1Adapter):
  """A `Dataset` of fixed-length records from one or more binary files."""

  def __init__(self,
               filenames,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               buffer_size=None,
               compression_type=None,
               num_parallel_reads=None,
               name=None):
    wrapped = FixedLengthRecordDatasetV2(
        filenames,
        record_bytes,
        header_bytes,
        footer_bytes,
        buffer_size,
        compression_type,
        num_parallel_reads,
        name=name)
    super(FixedLengthRecordDatasetV1, self).__init__(wrapped)

  __init__.__doc__ = FixedLengthRecordDatasetV2.__init__.__doc__

  @property
  def _filenames(self):
    return self._dataset._filenames  # pylint: disable=protected-access

  @_filenames.setter
  def _filenames(self, value):
    self._dataset._filenames = value  # pylint: disable=protected-access


if tf2.enabled():
  FixedLengthRecordDataset = FixedLengthRecordDatasetV2
  TFRecordDataset = TFRecordDatasetV2
  TextLineDataset = TextLineDatasetV2
else:
  FixedLengthRecordDataset = FixedLengthRecordDatasetV1
  TFRecordDataset = TFRecordDatasetV1
  TextLineDataset = TextLineDatasetV1

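# End-to-end usage sketch for the public readers exported above (illustrative
# only; the filenames and record layout are hypothetical):
#
#   lines = TextLineDataset(["/tmp/a.txt.gz"], compression_type="GZIP")
#   records = TFRecordDataset(["/tmp/a.tfrecord"], num_parallel_reads=4)
#   fixed = FixedLengthRecordDataset(
#       ["/tmp/a.bin"], record_bytes=8, header_bytes=4, footer_bytes=4)
#
#   # Each of these datasets yields scalar `tf.string` elements: one line, one
#   # TFRecord, or one fixed-length record per element.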