xref: /aosp_15_r20/external/tensorflow/tensorflow/python/lib/io/tf_record.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""For reading and writing TFRecords files."""
17
18from tensorflow.python.lib.io import _pywrap_record_io
19from tensorflow.python.util import compat
20from tensorflow.python.util import deprecation
21from tensorflow.python.util.tf_export import tf_export
22
23
24@tf_export(
25    v1=["io.TFRecordCompressionType", "python_io.TFRecordCompressionType"])
26@deprecation.deprecated_endpoints("io.TFRecordCompressionType",
27                                  "python_io.TFRecordCompressionType")
class TFRecordCompressionType(object):
  """Integer constants identifying how a TFRecord file is compressed.

  `NONE` means the records are stored uncompressed; `ZLIB` and `GZIP` name the
  corresponding compression formats.
  """
  NONE, ZLIB, GZIP = 0, 1, 2
33
34
35@tf_export(
36    "io.TFRecordOptions",
37    v1=["io.TFRecordOptions", "python_io.TFRecordOptions"])
38@deprecation.deprecated_endpoints("python_io.TFRecordOptions")
class TFRecordOptions(object):
  """Options used for manipulating TFRecord files."""
  # Maps the integer `TFRecordCompressionType` constants to the string names
  # returned by `get_compression_type_string`.
  compression_type_map = {
      TFRecordCompressionType.ZLIB: "ZLIB",
      TFRecordCompressionType.GZIP: "GZIP",
      TFRecordCompressionType.NONE: ""
  }

  def __init__(self,
               compression_type=None,
               flush_mode=None,
               input_buffer_size=None,
               output_buffer_size=None,
               window_bits=None,
               compression_level=None,
               compression_method=None,
               mem_level=None,
               compression_strategy=None):
    # pylint: disable=line-too-long
    """Creates a `TFRecordOptions` instance.

    Options only affect TFRecordWriter when compression_type is not `None`.
    Documentation, details, and defaults can be found in
    [`zlib_compression_options.h`](https://www.tensorflow.org/code/tensorflow/core/lib/io/zlib_compression_options.h)
    and in the [zlib manual](http://www.zlib.net/manual.html).
    Leaving an option as `None` allows C++ to set a reasonable default.

    Args:
      compression_type: `"GZIP"`, `"ZLIB"`, or `""` (no compression).
      flush_mode: flush mode or `None`, Default: Z_NO_FLUSH.
      input_buffer_size: int or `None`.
      output_buffer_size: int or `None`.
      window_bits: int or `None`.
      compression_level: 0 to 9, or `None`.
      compression_method: compression method or `None`.
      mem_level: 1 to 9, or `None`.
      compression_strategy: strategy or `None`. Default: Z_DEFAULT_STRATEGY.

    Returns:
      A `TFRecordOptions` object.

    Raises:
      ValueError: If compression_type is invalid.
    """
    # pylint: enable=line-too-long
    # Check compression_type is valid, but for backwards compatibility don't
    # immediately convert to a string: the attribute keeps whatever the caller
    # passed in.
    self.get_compression_type_string(compression_type)
    self.compression_type = compression_type
    self.flush_mode = flush_mode
    self.input_buffer_size = input_buffer_size
    self.output_buffer_size = output_buffer_size
    self.window_bits = window_bits
    self.compression_level = compression_level
    self.compression_method = compression_method
    self.mem_level = mem_level
    self.compression_strategy = compression_strategy

  @classmethod
  def get_compression_type_string(cls, options):
    """Convert various option types to a unified string.

    Args:
      options: `TFRecordOptions`, `TFRecordCompressionType`, or string.

    Returns:
      Compression type as string (e.g. `'ZLIB'`, `'GZIP'`, or `''`).

    Raises:
      ValueError: If compression_type is invalid.
    """
    # Any falsy value (`None`, `""`, `TFRecordCompressionType.NONE`) means
    # "no compression".
    if not options:
      return ""
    if isinstance(options, TFRecordOptions):
      return cls.get_compression_type_string(options.compression_type)
    try:
      # Integer `TFRecordCompressionType` constants are keys of the map.
      return cls.compression_type_map[options]
    except (KeyError, TypeError):
      # KeyError: not one of the known constants. TypeError: `options` is
      # unhashable. Neither may escape; fall through so that every invalid
      # input ends in the documented ValueError.
      pass
    # Already one of the canonical string names.
    if options in cls.compression_type_map.values():
      return options
    raise ValueError('Not a valid compression_type: "{}"'.format(options))

  def _as_record_writer_options(self):
    """Convert to RecordWriterOptions for use with PyRecordWriter."""
    options = _pywrap_record_io.RecordWriterOptions(
        compat.as_bytes(
            self.get_compression_type_string(self.compression_type)))

    # Only override the fields the caller set explicitly; anything left as
    # `None` keeps the value `RecordWriterOptions` was constructed with.
    if self.flush_mode is not None:
      options.zlib_options.flush_mode = self.flush_mode
    if self.input_buffer_size is not None:
      options.zlib_options.input_buffer_size = self.input_buffer_size
    if self.output_buffer_size is not None:
      options.zlib_options.output_buffer_size = self.output_buffer_size
    if self.window_bits is not None:
      options.zlib_options.window_bits = self.window_bits
    if self.compression_level is not None:
      options.zlib_options.compression_level = self.compression_level
    if self.compression_method is not None:
      options.zlib_options.compression_method = self.compression_method
    if self.mem_level is not None:
      options.zlib_options.mem_level = self.mem_level
    if self.compression_strategy is not None:
      options.zlib_options.compression_strategy = self.compression_strategy
    return options
146
147
148@tf_export(v1=["io.tf_record_iterator", "python_io.tf_record_iterator"])
149@deprecation.deprecated(
150    date=None,
151    instructions=("Use eager execution and: \n"
152                  "`tf.data.TFRecordDataset(path)`"))
def tf_record_iterator(path, options=None):
  """Returns an iterator that reads the records from a TFRecords file.

  Args:
    path: The path to the TFRecords file.
    options: (optional) A `TFRecordOptions` object, or any other value that
      `TFRecordOptions.get_compression_type_string` accepts.

  Returns:
    An iterator of serialized TFRecords.

  Raises:
    IOError: If `path` cannot be opened for reading.
    ValueError: If no valid compression type can be derived from `options`.
  """
  return _pywrap_record_io.RecordIterator(
      path, TFRecordOptions.get_compression_type_string(options))
168
169
def tf_record_random_reader(path):
  """Builds a reader supporting random-access reads from a TFRecords file.

  The returned reader exposes a single method:

    - `read(offset)`: returns the tuple `(record, ending_offset)`, where
      `record` is the TFRecord found at `offset` and `ending_offset` is the
      offset at which that record ends.

      A `tf.errors.DataLossError` is thrown if the data at the given offset is
      corrupted, and an `IndexError` is thrown if the offset is out of range
      for the TFRecords file.


  Usage example:
  ```py
  reader = tf_record_random_reader(file_path)

  record_1, offset_1 = reader.read(0)  # 0 is the initial offset.
  # offset_1 is the ending offset of the 1st record and the starting offset of
  # the next.

  record_2, offset_2 = reader.read(offset_1)
  # offset_2 is the ending offset of the 2nd record and the starting offset of
  # the next.
  # We can jump back and read the first record again if so desired.
  reader.read(0)
  ```

  Args:
    path: The path to the TFRecords file.

  Returns:
    An object that supports random-access reading of the serialized TFRecords.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  random_reader = _pywrap_record_io.RandomRecordReader(path)
  return random_reader
209
210
211@tf_export(
212    "io.TFRecordWriter", v1=["io.TFRecordWriter", "python_io.TFRecordWriter"])
213@deprecation.deprecated_endpoints("python_io.TFRecordWriter")
class TFRecordWriter(_pywrap_record_io.RecordWriter):
  """Writes records to a TFRecords file.

  [TFRecords tutorial](https://www.tensorflow.org/tutorials/load_data/tfrecord)

  TFRecords is a binary format which is optimized for high throughput data
  retrieval, generally in conjunction with `tf.data`. A `TFRecordWriter` stores
  serialized examples in a file so they can be consumed later. The key steps
  are:

   Ahead of time:

   - [Convert data into a serialized format](
   https://www.tensorflow.org/tutorials/load_data/tfrecord#tfexample)
   - [Write the serialized data to one or more files](
   https://www.tensorflow.org/tutorials/load_data/tfrecord#tfrecord_files_in_python)

   During training or evaluation:

   - [Read serialized examples into memory](
   https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file)
   - [Parse (deserialize) examples](
   https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file)

  A minimal example is given below:

  >>> import tempfile
  >>> example_path = os.path.join(tempfile.gettempdir(), "example.tfrecords")
  >>> np.random.seed(0)

  >>> # Write the records to a file.
  ... with tf.io.TFRecordWriter(example_path) as file_writer:
  ...   for _ in range(4):
  ...     x, y = np.random.random(), np.random.random()
  ...
  ...     record_bytes = tf.train.Example(features=tf.train.Features(feature={
  ...         "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
  ...         "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
  ...     })).SerializeToString()
  ...     file_writer.write(record_bytes)

  >>> # Read the data back out.
  >>> def decode_fn(record_bytes):
  ...   return tf.io.parse_single_example(
  ...       # Data
  ...       record_bytes,
  ...
  ...       # Schema
  ...       {"x": tf.io.FixedLenFeature([], dtype=tf.float32),
  ...        "y": tf.io.FixedLenFeature([], dtype=tf.float32)}
  ...   )

  >>> for batch in tf.data.TFRecordDataset([example_path]).map(decode_fn):
  ...   print("x = {x:.4f},  y = {y:.4f}".format(**batch))
  x = 0.5488,  y = 0.7152
  x = 0.6028,  y = 0.5449
  x = 0.4237,  y = 0.6459
  x = 0.4376,  y = 0.8918

  Because `__enter__` and `__exit__` are implemented, a writer can be used in
  a `with` block just like a normal file (as the example above does).
  """

  # TODO(josh11b): Support appending?
  def __init__(self, path, options=None):
    """Opens file `path` and creates a `TFRecordWriter` writing to it.

    Args:
      path: The path to the TFRecords file.
      options: (optional) String specifying compression type,
          `TFRecordCompressionType`, or `TFRecordOptions` object.

    Raises:
      IOError: If `path` cannot be opened for writing.
      ValueError: If valid compression_type can't be determined from `options`.
    """
    # Anything that is not already a `TFRecordOptions` is treated as a bare
    # compression type and wrapped, which also validates it.
    record_options = (
        options if isinstance(options, TFRecordOptions)
        else TFRecordOptions(compression_type=options))
    path_bytes = compat.as_bytes(path)

    # pylint: disable=protected-access
    writer_options = record_options._as_record_writer_options()
    # pylint: enable=protected-access
    super(TFRecordWriter, self).__init__(path_bytes, writer_options)

  # TODO(slebedev): The following wrapper methods are there to compensate
  # for lack of signatures in pybind11-generated classes. Switch to
  # __text_signature__ when TensorFlow drops Python 2.X support.
  # See https://github.com/pybind/pybind11/issues/945
  # pylint: disable=useless-super-delegation
  def write(self, record):
    """Writes a single string record to the file.

    Args:
      record: str, the record to write.
    """
    super(TFRecordWriter, self).write(record)

  def flush(self):
    """Flushes the file."""
    super(TFRecordWriter, self).flush()

  def close(self):
    """Closes the file."""
    super(TFRecordWriter, self).close()
  # pylint: enable=useless-super-delegation
319