xref: /aosp_15_r20/external/tensorflow/tensorflow/python/lib/io/record_io_wrapper.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 
19 #include "absl/memory/memory.h"
20 #include "pybind11/pybind11.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/stringpiece.h"
23 #include "tensorflow/core/lib/io/record_reader.h"
24 #include "tensorflow/core/lib/io/record_writer.h"
25 #include "tensorflow/core/lib/io/zlib_compression_options.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/errors.h"
28 #include "tensorflow/core/platform/file_system.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/python/lib/core/pybind11_absl.h"
31 #include "tensorflow/python/lib/core/pybind11_status.h"
32 
33 namespace {
34 
35 namespace py = ::pybind11;
36 
37 class PyRecordReader {
38  public:
39   // NOTE(sethtroisi): At this time PyRecordReader doesn't benefit from taking
40   // RecordReaderOptions, if this changes the API can be updated at that time.
New(const std::string & filename,const std::string & compression_type,PyRecordReader ** out)41   static tensorflow::Status New(const std::string& filename,
42                                 const std::string& compression_type,
43                                 PyRecordReader** out) {
44     auto tmp = new PyRecordReader(filename, compression_type);
45     TF_RETURN_IF_ERROR(tmp->Reopen());
46     *out = tmp;
47     return ::tensorflow::OkStatus();
48   }
49 
50   PyRecordReader() = delete;
~PyRecordReader()51   ~PyRecordReader() { Close(); }
52 
ReadNextRecord(tensorflow::tstring * out)53   tensorflow::Status ReadNextRecord(tensorflow::tstring* out) {
54     if (IsClosed()) {
55       return tensorflow::errors::FailedPrecondition("Reader is closed.");
56     }
57     return reader_->ReadRecord(&offset_, out);
58   }
59 
IsClosed() const60   bool IsClosed() const { return file_ == nullptr && reader_ == nullptr; }
61 
Close()62   void Close() {
63     reader_ = nullptr;
64     file_ = nullptr;
65   }
66 
67   // Reopen a closed writer by re-opening the file and re-creating the reader,
68   // but preserving the prior read offset. If not closed, returns an error.
69   //
70   // This is useful to allow "refreshing" the underlying file handle, in cases
71   // where the file was replaced with a newer version containing additional data
72   // that otherwise wouldn't be available via the existing file handle. This
73   // allows the file to be polled continuously using the same iterator, even as
74   // it grows, which supports use cases such as TensorBoard.
Reopen()75   tensorflow::Status Reopen() {
76     if (!IsClosed()) {
77       return tensorflow::errors::FailedPrecondition("Reader is not closed.");
78     }
79     TF_RETURN_IF_ERROR(
80         tensorflow::Env::Default()->NewRandomAccessFile(filename_, &file_));
81     reader_ =
82         absl::make_unique<tensorflow::io::RecordReader>(file_.get(), options_);
83     return ::tensorflow::OkStatus();
84   }
85 
86  private:
87   static constexpr tensorflow::uint64 kReaderBufferSize = 16 * 1024 * 1024;
88 
PyRecordReader(const std::string & filename,const std::string & compression_type)89   PyRecordReader(const std::string& filename,
90                  const std::string& compression_type)
91       : filename_(filename),
92         options_(CreateOptions(compression_type)),
93         offset_(0),
94         file_(nullptr),
95         reader_(nullptr) {}
96 
CreateOptions(const std::string & compression_type)97   static tensorflow::io::RecordReaderOptions CreateOptions(
98       const std::string& compression_type) {
99     auto options =
100         tensorflow::io::RecordReaderOptions::CreateRecordReaderOptions(
101             compression_type);
102     options.buffer_size = kReaderBufferSize;
103     return options;
104   }
105 
106   const std::string filename_;
107   const tensorflow::io::RecordReaderOptions options_;
108   tensorflow::uint64 offset_;
109   std::unique_ptr<tensorflow::RandomAccessFile> file_;
110   std::unique_ptr<tensorflow::io::RecordReader> reader_;
111 
112   TF_DISALLOW_COPY_AND_ASSIGN(PyRecordReader);
113 };
114 
115 class PyRecordRandomReader {
116  public:
New(const std::string & filename,PyRecordRandomReader ** out)117   static tensorflow::Status New(const std::string& filename,
118                                 PyRecordRandomReader** out) {
119     std::unique_ptr<tensorflow::RandomAccessFile> file;
120     TF_RETURN_IF_ERROR(
121         tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
122     auto options =
123         tensorflow::io::RecordReaderOptions::CreateRecordReaderOptions("");
124     options.buffer_size = kReaderBufferSize;
125     auto reader =
126         absl::make_unique<tensorflow::io::RecordReader>(file.get(), options);
127     *out = new PyRecordRandomReader(std::move(file), std::move(reader));
128     return ::tensorflow::OkStatus();
129   }
130 
131   PyRecordRandomReader() = delete;
~PyRecordRandomReader()132   ~PyRecordRandomReader() { Close(); }
133 
ReadRecord(tensorflow::uint64 * offset,tensorflow::tstring * out)134   tensorflow::Status ReadRecord(tensorflow::uint64* offset,
135                                 tensorflow::tstring* out) {
136     if (IsClosed()) {
137       return tensorflow::errors::FailedPrecondition(
138           "Random TFRecord Reader is closed.");
139     }
140     return reader_->ReadRecord(offset, out);
141   }
142 
IsClosed() const143   bool IsClosed() const { return file_ == nullptr && reader_ == nullptr; }
144 
Close()145   void Close() {
146     reader_ = nullptr;
147     file_ = nullptr;
148   }
149 
150  private:
151   static constexpr tensorflow::uint64 kReaderBufferSize = 16 * 1024 * 1024;
152 
PyRecordRandomReader(std::unique_ptr<tensorflow::RandomAccessFile> file,std::unique_ptr<tensorflow::io::RecordReader> reader)153   PyRecordRandomReader(std::unique_ptr<tensorflow::RandomAccessFile> file,
154                        std::unique_ptr<tensorflow::io::RecordReader> reader)
155       : file_(std::move(file)), reader_(std::move(reader)) {}
156 
157   std::unique_ptr<tensorflow::RandomAccessFile> file_;
158   std::unique_ptr<tensorflow::io::RecordReader> reader_;
159 
160   TF_DISALLOW_COPY_AND_ASSIGN(PyRecordRandomReader);
161 };
162 
163 class PyRecordWriter {
164  public:
New(const std::string & filename,const tensorflow::io::RecordWriterOptions & options,PyRecordWriter ** out)165   static tensorflow::Status New(
166       const std::string& filename,
167       const tensorflow::io::RecordWriterOptions& options,
168       PyRecordWriter** out) {
169     std::unique_ptr<tensorflow::WritableFile> file;
170     TF_RETURN_IF_ERROR(
171         tensorflow::Env::Default()->NewWritableFile(filename, &file));
172     auto writer =
173         absl::make_unique<tensorflow::io::RecordWriter>(file.get(), options);
174     *out = new PyRecordWriter(std::move(file), std::move(writer));
175     return ::tensorflow::OkStatus();
176   }
177 
178   PyRecordWriter() = delete;
~PyRecordWriter()179   ~PyRecordWriter() { Close(); }
180 
WriteRecord(tensorflow::StringPiece record)181   tensorflow::Status WriteRecord(tensorflow::StringPiece record) {
182     if (IsClosed()) {
183       return tensorflow::errors::FailedPrecondition("Writer is closed.");
184     }
185     return writer_->WriteRecord(record);
186   }
187 
Flush()188   tensorflow::Status Flush() {
189     if (IsClosed()) {
190       return tensorflow::errors::FailedPrecondition("Writer is closed.");
191     }
192 
193     auto status = writer_->Flush();
194     if (status.ok()) {
195       // Per the RecordWriter contract, flushing the RecordWriter does not
196       // flush the underlying file.  Here we need to do both.
197       return file_->Flush();
198     }
199     return status;
200   }
201 
IsClosed() const202   bool IsClosed() const { return file_ == nullptr && writer_ == nullptr; }
203 
Close()204   tensorflow::Status Close() {
205     if (writer_ != nullptr) {
206       auto status = writer_->Close();
207       writer_ = nullptr;
208       if (!status.ok()) return status;
209     }
210     if (file_ != nullptr) {
211       auto status = file_->Close();
212       file_ = nullptr;
213       if (!status.ok()) return status;
214     }
215     return ::tensorflow::OkStatus();
216   }
217 
218  private:
PyRecordWriter(std::unique_ptr<tensorflow::WritableFile> file,std::unique_ptr<tensorflow::io::RecordWriter> writer)219   PyRecordWriter(std::unique_ptr<tensorflow::WritableFile> file,
220                  std::unique_ptr<tensorflow::io::RecordWriter> writer)
221       : file_(std::move(file)), writer_(std::move(writer)) {}
222 
223   std::unique_ptr<tensorflow::WritableFile> file_;
224   std::unique_ptr<tensorflow::io::RecordWriter> writer_;
225 
226   TF_DISALLOW_COPY_AND_ASSIGN(PyRecordWriter);
227 };
228 
PYBIND11_MODULE(_pywrap_record_io,m)229 PYBIND11_MODULE(_pywrap_record_io, m) {
230   py::class_<PyRecordReader>(m, "RecordIterator")
231       .def(py::init(
232           [](const std::string& filename, const std::string& compression_type) {
233             tensorflow::Status status;
234             PyRecordReader* self = nullptr;
235             {
236               py::gil_scoped_release release;
237               status = PyRecordReader::New(filename, compression_type, &self);
238             }
239             MaybeRaiseRegisteredFromStatus(status);
240             return self;
241           }))
242       .def("__iter__", [](const py::object& self) { return self; })
243       .def("__next__",
244            [](PyRecordReader* self) {
245              if (self->IsClosed()) {
246                throw py::stop_iteration();
247              }
248 
249              tensorflow::tstring record;
250              tensorflow::Status status;
251              {
252                py::gil_scoped_release release;
253                status = self->ReadNextRecord(&record);
254              }
255              if (tensorflow::errors::IsOutOfRange(status)) {
256                // Don't close because the file being read could be updated
257                // in-between
258                // __next__ calls.
259                throw py::stop_iteration();
260              }
261              MaybeRaiseRegisteredFromStatus(status);
262              return py::bytes(record);
263            })
264       .def("close", [](PyRecordReader* self) { self->Close(); })
265       .def("reopen", [](PyRecordReader* self) {
266         tensorflow::Status status;
267         {
268           py::gil_scoped_release release;
269           status = self->Reopen();
270         }
271         MaybeRaiseRegisteredFromStatus(status);
272       });
273 
274   py::class_<PyRecordRandomReader>(m, "RandomRecordReader")
275       .def(py::init([](const std::string& filename) {
276         tensorflow::Status status;
277         PyRecordRandomReader* self = nullptr;
278         {
279           py::gil_scoped_release release;
280           status = PyRecordRandomReader::New(filename, &self);
281         }
282         MaybeRaiseRegisteredFromStatus(status);
283         return self;
284       }))
285       .def("read",
286            [](PyRecordRandomReader* self, tensorflow::uint64 offset) {
287              tensorflow::uint64 temp_offset = offset;
288              tensorflow::tstring record;
289              tensorflow::Status status;
290              {
291                py::gil_scoped_release release;
292                status = self->ReadRecord(&temp_offset, &record);
293              }
294              if (tensorflow::errors::IsOutOfRange(status)) {
295                throw py::index_error(tensorflow::strings::StrCat(
296                    "Out of range at reading offset ", offset));
297              }
298              MaybeRaiseRegisteredFromStatus(status);
299              return py::make_tuple(py::bytes(record), temp_offset);
300            })
301       .def("close", [](PyRecordRandomReader* self) { self->Close(); });
302 
303   using tensorflow::io::ZlibCompressionOptions;
304   py::class_<ZlibCompressionOptions>(m, "ZlibCompressionOptions")
305       .def_readwrite("flush_mode", &ZlibCompressionOptions::flush_mode)
306       .def_readwrite("input_buffer_size",
307                      &ZlibCompressionOptions::input_buffer_size)
308       .def_readwrite("output_buffer_size",
309                      &ZlibCompressionOptions::output_buffer_size)
310       .def_readwrite("window_bits", &ZlibCompressionOptions::window_bits)
311       .def_readwrite("compression_level",
312                      &ZlibCompressionOptions::compression_level)
313       .def_readwrite("compression_method",
314                      &ZlibCompressionOptions::compression_method)
315       .def_readwrite("mem_level", &ZlibCompressionOptions::mem_level)
316       .def_readwrite("compression_strategy",
317                      &ZlibCompressionOptions::compression_strategy);
318 
319   using tensorflow::io::RecordWriterOptions;
320   py::class_<RecordWriterOptions>(m, "RecordWriterOptions")
321       .def(py::init(&RecordWriterOptions::CreateRecordWriterOptions))
322       .def_readonly("compression_type", &RecordWriterOptions::compression_type)
323       .def_readonly("zlib_options", &RecordWriterOptions::zlib_options);
324 
325   using tensorflow::MaybeRaiseRegisteredFromStatus;
326 
327   py::class_<PyRecordWriter>(m, "RecordWriter")
328       .def(py::init(
329           [](const std::string& filename, const RecordWriterOptions& options) {
330             PyRecordWriter* self = nullptr;
331             tensorflow::Status status;
332             {
333               py::gil_scoped_release release;
334               status = PyRecordWriter::New(filename, options, &self);
335             }
336             MaybeRaiseRegisteredFromStatus(status);
337             return self;
338           }))
339       .def("__enter__", [](const py::object& self) { return self; })
340       .def("__exit__",
341            [](PyRecordWriter* self, py::args) {
342              MaybeRaiseRegisteredFromStatus(self->Close());
343            })
344       .def(
345           "write",
346           [](PyRecordWriter* self, tensorflow::StringPiece record) {
347             tensorflow::Status status;
348             {
349               py::gil_scoped_release release;
350               status = self->WriteRecord(record);
351             }
352             MaybeRaiseRegisteredFromStatus(status);
353           },
354           py::arg("record"))
355       .def("flush",
356            [](PyRecordWriter* self) {
357              MaybeRaiseRegisteredFromStatus(self->Flush());
358            })
359       .def("close", [](PyRecordWriter* self) {
360         MaybeRaiseRegisteredFromStatus(self->Close());
361       });
362 }
363 
364 }  // namespace
365