# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""For reading and writing TFRecords files."""

from tensorflow.python.lib.io import _pywrap_record_io
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export(
    v1=["io.TFRecordCompressionType", "python_io.TFRecordCompressionType"])
@deprecation.deprecated_endpoints("io.TFRecordCompressionType",
                                  "python_io.TFRecordCompressionType")
class TFRecordCompressionType(object):
  """The type of compression for the record."""
  NONE = 0
  ZLIB = 1
  GZIP = 2


@tf_export(
    "io.TFRecordOptions",
    v1=["io.TFRecordOptions", "python_io.TFRecordOptions"])
@deprecation.deprecated_endpoints("python_io.TFRecordOptions")
class TFRecordOptions(object):
  """Options used for manipulating TFRecord files."""
  # Maps each `TFRecordCompressionType` value to the compression-type string
  # returned by `get_compression_type_string`.
  compression_type_map = {
      TFRecordCompressionType.ZLIB: "ZLIB",
      TFRecordCompressionType.GZIP: "GZIP",
      TFRecordCompressionType.NONE: ""
  }

  def __init__(self,
               compression_type=None,
               flush_mode=None,
               input_buffer_size=None,
               output_buffer_size=None,
               window_bits=None,
               compression_level=None,
               compression_method=None,
               mem_level=None,
               compression_strategy=None):
    # pylint: disable=line-too-long
    """Creates a `TFRecordOptions` instance.

    Options only affect TFRecordWriter when compression_type is not `None`.
    Documentation, details, and defaults can be found in
    [`zlib_compression_options.h`](https://www.tensorflow.org/code/tensorflow/core/lib/io/zlib_compression_options.h)
    and in the [zlib manual](http://www.zlib.net/manual.html).
    Leaving an option as `None` allows C++ to set a reasonable default.

    Args:
      compression_type: `"GZIP"`, `"ZLIB"`, or `""` (no compression).
      flush_mode: flush mode or `None`, Default: Z_NO_FLUSH.
      input_buffer_size: int or `None`.
      output_buffer_size: int or `None`.
      window_bits: int or `None`.
      compression_level: 0 to 9, or `None`.
      compression_method: compression method or `None`.
      mem_level: 1 to 9, or `None`.
      compression_strategy: strategy or `None`. Default: Z_DEFAULT_STRATEGY.

    Returns:
      A `TFRecordOptions` object.

    Raises:
      ValueError: If compression_type is invalid.
    """
    # pylint: enable=line-too-long
    # Check compression_type is valid, but for backwards compatibility don't
    # immediately convert to a string.
    self.get_compression_type_string(compression_type)
    self.compression_type = compression_type
    self.flush_mode = flush_mode
    self.input_buffer_size = input_buffer_size
    self.output_buffer_size = output_buffer_size
    self.window_bits = window_bits
    self.compression_level = compression_level
    self.compression_method = compression_method
    self.mem_level = mem_level
    self.compression_strategy = compression_strategy

  @classmethod
  def get_compression_type_string(cls, options):
    """Convert various option types to a unified string.

    Args:
      options: `TFRecordOptions`, `TFRecordCompressionType`, or string.

    Returns:
      Compression type as string (e.g. `'ZLIB'`, `'GZIP'`, or `''`).

    Raises:
      ValueError: If compression_type is invalid.
    """
    # Any falsy value (`None`, `""`, `TFRecordCompressionType.NONE`, ...) means
    # "no compression".
    if not options:
      return ""
    if isinstance(options, TFRecordOptions):
      return cls.get_compression_type_string(options.compression_type)

    # `TFRecordCompressionType` members are plain ints, so they are resolved
    # through the map like any other key. The membership test hashes `options`;
    # an unhashable value (e.g. a list) can never be a key, so treat it as
    # "not found" and let it reach the documented `ValueError` below instead of
    # leaking a `TypeError`.
    try:
      is_known_key = options in cls.compression_type_map
    except TypeError:
      is_known_key = False
    if is_known_key:
      return cls.compression_type_map[options]

    # Already one of the canonical strings: hand it back unchanged.
    if options in cls.compression_type_map.values():
      return options
    raise ValueError('Not a valid compression_type: "{}"'.format(options))

  def _as_record_writer_options(self):
    """Convert to RecordWriterOptions for use with PyRecordWriter."""
    options = _pywrap_record_io.RecordWriterOptions(
        compat.as_bytes(
            self.get_compression_type_string(self.compression_type)))

    # Only copy the zlib options that were explicitly set; anything left as
    # `None` keeps whatever value `RecordWriterOptions` was constructed with.
    if self.flush_mode is not None:
      options.zlib_options.flush_mode = self.flush_mode
    if self.input_buffer_size is not None:
      options.zlib_options.input_buffer_size = self.input_buffer_size
    if self.output_buffer_size is not None:
      options.zlib_options.output_buffer_size = self.output_buffer_size
    if self.window_bits is not None:
      options.zlib_options.window_bits = self.window_bits
    if self.compression_level is not None:
      options.zlib_options.compression_level = self.compression_level
    if self.compression_method is not None:
      options.zlib_options.compression_method = self.compression_method
    if self.mem_level is not None:
      options.zlib_options.mem_level = self.mem_level
    if self.compression_strategy is not None:
      options.zlib_options.compression_strategy = self.compression_strategy
    return options


@tf_export(v1=["io.tf_record_iterator", "python_io.tf_record_iterator"])
@deprecation.deprecated(
    date=None,
    instructions=("Use eager execution and: \n"
                  "`tf.data.TFRecordDataset(path)`"))
def tf_record_iterator(path, options=None):
  """An iterator that reads the records from a TFRecords file.

  Args:
    path: The path to the TFRecords file.
    options: (optional) A TFRecordOptions object.

  Returns:
    An iterator of serialized TFRecords.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  compression_type = TFRecordOptions.get_compression_type_string(options)
  return _pywrap_record_io.RecordIterator(path, compression_type)


def tf_record_random_reader(path):
  """Creates a reader that allows random-access reads from a TFRecords file.

  The created reader object has the following method:

    - `read(offset)`, which returns a tuple of `(record, ending_offset)`, where
      `record` is the TFRecord read at the offset, and
      `ending_offset` is the ending offset of the read record.

      The method throws a `tf.errors.DataLossError` if data is corrupted at
      the given offset. The method throws `IndexError` if the offset is out of
      range for the TFRecords file.


  Usage example:
  ```py
  reader = tf_record_random_reader(file_path)

  record_1, offset_1 = reader.read(0)  # 0 is the initial offset.
  # offset_1 is the ending offset of the 1st record and the starting offset of
  # the next.

  record_2, offset_2 = reader.read(offset_1)
  # offset_2 is the ending offset of the 2nd record and the starting offset of
  # the next.
  # We can jump back and read the first record again if so desired.
  reader.read(0)
  ```

  Args:
    path: The path to the TFRecords file.

  Returns:
    An object that supports random-access reading of the serialized TFRecords.

  Raises:
    IOError: If `path` cannot be opened for reading.
  """
  return _pywrap_record_io.RandomRecordReader(path)


@tf_export(
    "io.TFRecordWriter", v1=["io.TFRecordWriter", "python_io.TFRecordWriter"])
@deprecation.deprecated_endpoints("python_io.TFRecordWriter")
class TFRecordWriter(_pywrap_record_io.RecordWriter):
  """A class to write records to a TFRecords file.

  [TFRecords tutorial](https://www.tensorflow.org/tutorials/load_data/tfrecord)

  TFRecords is a binary format which is optimized for high throughput data
  retrieval, generally in conjunction with `tf.data`. `TFRecordWriter` is used
  to write serialized examples to a file for later consumption. The key steps
  are:

   Ahead of time:

   - [Convert data into a serialized format](
   https://www.tensorflow.org/tutorials/load_data/tfrecord#tfexample)
   - [Write the serialized data to one or more files](
   https://www.tensorflow.org/tutorials/load_data/tfrecord#tfrecord_files_in_python)

   During training or evaluation:

   - [Read serialized examples into memory](
   https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file)
   - [Parse (deserialize) examples](
   https://www.tensorflow.org/tutorials/load_data/tfrecord#reading_a_tfrecord_file)

  A minimal example is given below:

  >>> import tempfile
  >>> example_path = os.path.join(tempfile.gettempdir(), "example.tfrecords")
  >>> np.random.seed(0)

  >>> # Write the records to a file.
  ... with tf.io.TFRecordWriter(example_path) as file_writer:
  ...   for _ in range(4):
  ...     x, y = np.random.random(), np.random.random()
  ...
  ...     record_bytes = tf.train.Example(features=tf.train.Features(feature={
  ...         "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
  ...         "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
  ...     })).SerializeToString()
  ...     file_writer.write(record_bytes)

  >>> # Read the data back out.
  >>> def decode_fn(record_bytes):
  ...   return tf.io.parse_single_example(
  ...       # Data
  ...       record_bytes,
  ...
  ...       # Schema
  ...       {"x": tf.io.FixedLenFeature([], dtype=tf.float32),
  ...        "y": tf.io.FixedLenFeature([], dtype=tf.float32)}
  ...   )

  >>> for batch in tf.data.TFRecordDataset([example_path]).map(decode_fn):
  ...   print("x = {x:.4f}, y = {y:.4f}".format(**batch))
  x = 0.5488, y = 0.7152
  x = 0.6028, y = 0.5449
  x = 0.4237, y = 0.6459
  x = 0.4376, y = 0.8918

  This class implements `__enter__` and `__exit__`, and can be used
  in `with` blocks like a normal file. (See the usage example above.)
  """

  # TODO(josh11b): Support appending?
  def __init__(self, path, options=None):
    """Opens file `path` and creates a `TFRecordWriter` writing to it.

    Args:
      path: The path to the TFRecords file.
      options: (optional) String specifying compression type,
          `TFRecordCompressionType`, or `TFRecordOptions` object.

    Raises:
      IOError: If `path` cannot be opened for writing.
      ValueError: If valid compression_type can't be determined from `options`.
    """
    # Anything that is not already a `TFRecordOptions` is interpreted as a
    # compression type; the constructor validates it.
    if not isinstance(options, TFRecordOptions):
      options = TFRecordOptions(compression_type=options)

    # pylint: disable=protected-access
    super(TFRecordWriter, self).__init__(
        compat.as_bytes(path), options._as_record_writer_options())
    # pylint: enable=protected-access

  # TODO(slebedev): The following wrapper methods are there to compensate
  # for lack of signatures in pybind11-generated classes. Switch to
  # __text_signature__ when TensorFlow drops Python 2.X support.
  # See https://github.com/pybind/pybind11/issues/945
  # pylint: disable=useless-super-delegation
  def write(self, record):
    """Write a string record to the file.

    Args:
      record: str
    """
    super(TFRecordWriter, self).write(record)

  def flush(self):
    """Flush the file."""
    super(TFRecordWriter, self).flush()

  def close(self):
    """Close the file."""
    super(TFRecordWriter, self).close()
  # pylint: enable=useless-super-delegation