1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <string>
18
19 #include "absl/memory/memory.h"
20 #include "pybind11/pybind11.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/stringpiece.h"
23 #include "tensorflow/core/lib/io/record_reader.h"
24 #include "tensorflow/core/lib/io/record_writer.h"
25 #include "tensorflow/core/lib/io/zlib_compression_options.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/file_system.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/python/lib/core/pybind11_absl.h"
31 #include "tensorflow/python/lib/core/pybind11_status.h"
32
33 namespace {
34
// Short alias for the pybind11 namespace, used throughout the bindings below.
namespace py = ::pybind11;
36
37 class PyRecordReader {
38 public:
39 // NOTE(sethtroisi): At this time PyRecordReader doesn't benefit from taking
40 // RecordReaderOptions, if this changes the API can be updated at that time.
New(const std::string & filename,const std::string & compression_type,PyRecordReader ** out)41 static tensorflow::Status New(const std::string& filename,
42 const std::string& compression_type,
43 PyRecordReader** out) {
44 auto tmp = new PyRecordReader(filename, compression_type);
45 TF_RETURN_IF_ERROR(tmp->Reopen());
46 *out = tmp;
47 return ::tensorflow::OkStatus();
48 }
49
50 PyRecordReader() = delete;
~PyRecordReader()51 ~PyRecordReader() { Close(); }
52
ReadNextRecord(tensorflow::tstring * out)53 tensorflow::Status ReadNextRecord(tensorflow::tstring* out) {
54 if (IsClosed()) {
55 return tensorflow::errors::FailedPrecondition("Reader is closed.");
56 }
57 return reader_->ReadRecord(&offset_, out);
58 }
59
IsClosed() const60 bool IsClosed() const { return file_ == nullptr && reader_ == nullptr; }
61
Close()62 void Close() {
63 reader_ = nullptr;
64 file_ = nullptr;
65 }
66
67 // Reopen a closed writer by re-opening the file and re-creating the reader,
68 // but preserving the prior read offset. If not closed, returns an error.
69 //
70 // This is useful to allow "refreshing" the underlying file handle, in cases
71 // where the file was replaced with a newer version containing additional data
72 // that otherwise wouldn't be available via the existing file handle. This
73 // allows the file to be polled continuously using the same iterator, even as
74 // it grows, which supports use cases such as TensorBoard.
Reopen()75 tensorflow::Status Reopen() {
76 if (!IsClosed()) {
77 return tensorflow::errors::FailedPrecondition("Reader is not closed.");
78 }
79 TF_RETURN_IF_ERROR(
80 tensorflow::Env::Default()->NewRandomAccessFile(filename_, &file_));
81 reader_ =
82 absl::make_unique<tensorflow::io::RecordReader>(file_.get(), options_);
83 return ::tensorflow::OkStatus();
84 }
85
86 private:
87 static constexpr tensorflow::uint64 kReaderBufferSize = 16 * 1024 * 1024;
88
PyRecordReader(const std::string & filename,const std::string & compression_type)89 PyRecordReader(const std::string& filename,
90 const std::string& compression_type)
91 : filename_(filename),
92 options_(CreateOptions(compression_type)),
93 offset_(0),
94 file_(nullptr),
95 reader_(nullptr) {}
96
CreateOptions(const std::string & compression_type)97 static tensorflow::io::RecordReaderOptions CreateOptions(
98 const std::string& compression_type) {
99 auto options =
100 tensorflow::io::RecordReaderOptions::CreateRecordReaderOptions(
101 compression_type);
102 options.buffer_size = kReaderBufferSize;
103 return options;
104 }
105
106 const std::string filename_;
107 const tensorflow::io::RecordReaderOptions options_;
108 tensorflow::uint64 offset_;
109 std::unique_ptr<tensorflow::RandomAccessFile> file_;
110 std::unique_ptr<tensorflow::io::RecordReader> reader_;
111
112 TF_DISALLOW_COPY_AND_ASSIGN(PyRecordReader);
113 };
114
115 class PyRecordRandomReader {
116 public:
New(const std::string & filename,PyRecordRandomReader ** out)117 static tensorflow::Status New(const std::string& filename,
118 PyRecordRandomReader** out) {
119 std::unique_ptr<tensorflow::RandomAccessFile> file;
120 TF_RETURN_IF_ERROR(
121 tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
122 auto options =
123 tensorflow::io::RecordReaderOptions::CreateRecordReaderOptions("");
124 options.buffer_size = kReaderBufferSize;
125 auto reader =
126 absl::make_unique<tensorflow::io::RecordReader>(file.get(), options);
127 *out = new PyRecordRandomReader(std::move(file), std::move(reader));
128 return ::tensorflow::OkStatus();
129 }
130
131 PyRecordRandomReader() = delete;
~PyRecordRandomReader()132 ~PyRecordRandomReader() { Close(); }
133
ReadRecord(tensorflow::uint64 * offset,tensorflow::tstring * out)134 tensorflow::Status ReadRecord(tensorflow::uint64* offset,
135 tensorflow::tstring* out) {
136 if (IsClosed()) {
137 return tensorflow::errors::FailedPrecondition(
138 "Random TFRecord Reader is closed.");
139 }
140 return reader_->ReadRecord(offset, out);
141 }
142
IsClosed() const143 bool IsClosed() const { return file_ == nullptr && reader_ == nullptr; }
144
Close()145 void Close() {
146 reader_ = nullptr;
147 file_ = nullptr;
148 }
149
150 private:
151 static constexpr tensorflow::uint64 kReaderBufferSize = 16 * 1024 * 1024;
152
PyRecordRandomReader(std::unique_ptr<tensorflow::RandomAccessFile> file,std::unique_ptr<tensorflow::io::RecordReader> reader)153 PyRecordRandomReader(std::unique_ptr<tensorflow::RandomAccessFile> file,
154 std::unique_ptr<tensorflow::io::RecordReader> reader)
155 : file_(std::move(file)), reader_(std::move(reader)) {}
156
157 std::unique_ptr<tensorflow::RandomAccessFile> file_;
158 std::unique_ptr<tensorflow::io::RecordReader> reader_;
159
160 TF_DISALLOW_COPY_AND_ASSIGN(PyRecordRandomReader);
161 };
162
163 class PyRecordWriter {
164 public:
New(const std::string & filename,const tensorflow::io::RecordWriterOptions & options,PyRecordWriter ** out)165 static tensorflow::Status New(
166 const std::string& filename,
167 const tensorflow::io::RecordWriterOptions& options,
168 PyRecordWriter** out) {
169 std::unique_ptr<tensorflow::WritableFile> file;
170 TF_RETURN_IF_ERROR(
171 tensorflow::Env::Default()->NewWritableFile(filename, &file));
172 auto writer =
173 absl::make_unique<tensorflow::io::RecordWriter>(file.get(), options);
174 *out = new PyRecordWriter(std::move(file), std::move(writer));
175 return ::tensorflow::OkStatus();
176 }
177
178 PyRecordWriter() = delete;
~PyRecordWriter()179 ~PyRecordWriter() { Close(); }
180
WriteRecord(tensorflow::StringPiece record)181 tensorflow::Status WriteRecord(tensorflow::StringPiece record) {
182 if (IsClosed()) {
183 return tensorflow::errors::FailedPrecondition("Writer is closed.");
184 }
185 return writer_->WriteRecord(record);
186 }
187
Flush()188 tensorflow::Status Flush() {
189 if (IsClosed()) {
190 return tensorflow::errors::FailedPrecondition("Writer is closed.");
191 }
192
193 auto status = writer_->Flush();
194 if (status.ok()) {
195 // Per the RecordWriter contract, flushing the RecordWriter does not
196 // flush the underlying file. Here we need to do both.
197 return file_->Flush();
198 }
199 return status;
200 }
201
IsClosed() const202 bool IsClosed() const { return file_ == nullptr && writer_ == nullptr; }
203
Close()204 tensorflow::Status Close() {
205 if (writer_ != nullptr) {
206 auto status = writer_->Close();
207 writer_ = nullptr;
208 if (!status.ok()) return status;
209 }
210 if (file_ != nullptr) {
211 auto status = file_->Close();
212 file_ = nullptr;
213 if (!status.ok()) return status;
214 }
215 return ::tensorflow::OkStatus();
216 }
217
218 private:
PyRecordWriter(std::unique_ptr<tensorflow::WritableFile> file,std::unique_ptr<tensorflow::io::RecordWriter> writer)219 PyRecordWriter(std::unique_ptr<tensorflow::WritableFile> file,
220 std::unique_ptr<tensorflow::io::RecordWriter> writer)
221 : file_(std::move(file)), writer_(std::move(writer)) {}
222
223 std::unique_ptr<tensorflow::WritableFile> file_;
224 std::unique_ptr<tensorflow::io::RecordWriter> writer_;
225
226 TF_DISALLOW_COPY_AND_ASSIGN(PyRecordWriter);
227 };
228
PYBIND11_MODULE(_pywrap_record_io,m)229 PYBIND11_MODULE(_pywrap_record_io, m) {
230 py::class_<PyRecordReader>(m, "RecordIterator")
231 .def(py::init(
232 [](const std::string& filename, const std::string& compression_type) {
233 tensorflow::Status status;
234 PyRecordReader* self = nullptr;
235 {
236 py::gil_scoped_release release;
237 status = PyRecordReader::New(filename, compression_type, &self);
238 }
239 MaybeRaiseRegisteredFromStatus(status);
240 return self;
241 }))
242 .def("__iter__", [](const py::object& self) { return self; })
243 .def("__next__",
244 [](PyRecordReader* self) {
245 if (self->IsClosed()) {
246 throw py::stop_iteration();
247 }
248
249 tensorflow::tstring record;
250 tensorflow::Status status;
251 {
252 py::gil_scoped_release release;
253 status = self->ReadNextRecord(&record);
254 }
255 if (tensorflow::errors::IsOutOfRange(status)) {
256 // Don't close because the file being read could be updated
257 // in-between
258 // __next__ calls.
259 throw py::stop_iteration();
260 }
261 MaybeRaiseRegisteredFromStatus(status);
262 return py::bytes(record);
263 })
264 .def("close", [](PyRecordReader* self) { self->Close(); })
265 .def("reopen", [](PyRecordReader* self) {
266 tensorflow::Status status;
267 {
268 py::gil_scoped_release release;
269 status = self->Reopen();
270 }
271 MaybeRaiseRegisteredFromStatus(status);
272 });
273
274 py::class_<PyRecordRandomReader>(m, "RandomRecordReader")
275 .def(py::init([](const std::string& filename) {
276 tensorflow::Status status;
277 PyRecordRandomReader* self = nullptr;
278 {
279 py::gil_scoped_release release;
280 status = PyRecordRandomReader::New(filename, &self);
281 }
282 MaybeRaiseRegisteredFromStatus(status);
283 return self;
284 }))
285 .def("read",
286 [](PyRecordRandomReader* self, tensorflow::uint64 offset) {
287 tensorflow::uint64 temp_offset = offset;
288 tensorflow::tstring record;
289 tensorflow::Status status;
290 {
291 py::gil_scoped_release release;
292 status = self->ReadRecord(&temp_offset, &record);
293 }
294 if (tensorflow::errors::IsOutOfRange(status)) {
295 throw py::index_error(tensorflow::strings::StrCat(
296 "Out of range at reading offset ", offset));
297 }
298 MaybeRaiseRegisteredFromStatus(status);
299 return py::make_tuple(py::bytes(record), temp_offset);
300 })
301 .def("close", [](PyRecordRandomReader* self) { self->Close(); });
302
303 using tensorflow::io::ZlibCompressionOptions;
304 py::class_<ZlibCompressionOptions>(m, "ZlibCompressionOptions")
305 .def_readwrite("flush_mode", &ZlibCompressionOptions::flush_mode)
306 .def_readwrite("input_buffer_size",
307 &ZlibCompressionOptions::input_buffer_size)
308 .def_readwrite("output_buffer_size",
309 &ZlibCompressionOptions::output_buffer_size)
310 .def_readwrite("window_bits", &ZlibCompressionOptions::window_bits)
311 .def_readwrite("compression_level",
312 &ZlibCompressionOptions::compression_level)
313 .def_readwrite("compression_method",
314 &ZlibCompressionOptions::compression_method)
315 .def_readwrite("mem_level", &ZlibCompressionOptions::mem_level)
316 .def_readwrite("compression_strategy",
317 &ZlibCompressionOptions::compression_strategy);
318
319 using tensorflow::io::RecordWriterOptions;
320 py::class_<RecordWriterOptions>(m, "RecordWriterOptions")
321 .def(py::init(&RecordWriterOptions::CreateRecordWriterOptions))
322 .def_readonly("compression_type", &RecordWriterOptions::compression_type)
323 .def_readonly("zlib_options", &RecordWriterOptions::zlib_options);
324
325 using tensorflow::MaybeRaiseRegisteredFromStatus;
326
327 py::class_<PyRecordWriter>(m, "RecordWriter")
328 .def(py::init(
329 [](const std::string& filename, const RecordWriterOptions& options) {
330 PyRecordWriter* self = nullptr;
331 tensorflow::Status status;
332 {
333 py::gil_scoped_release release;
334 status = PyRecordWriter::New(filename, options, &self);
335 }
336 MaybeRaiseRegisteredFromStatus(status);
337 return self;
338 }))
339 .def("__enter__", [](const py::object& self) { return self; })
340 .def("__exit__",
341 [](PyRecordWriter* self, py::args) {
342 MaybeRaiseRegisteredFromStatus(self->Close());
343 })
344 .def(
345 "write",
346 [](PyRecordWriter* self, tensorflow::StringPiece record) {
347 tensorflow::Status status;
348 {
349 py::gil_scoped_release release;
350 status = self->WriteRecord(record);
351 }
352 MaybeRaiseRegisteredFromStatus(status);
353 },
354 py::arg("record"))
355 .def("flush",
356 [](PyRecordWriter* self) {
357 MaybeRaiseRegisteredFromStatus(self->Flush());
358 })
359 .def("close", [](PyRecordWriter* self) {
360 MaybeRaiseRegisteredFromStatus(self->Close());
361 });
362 }
363
364 } // namespace
365