1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory.h"
18
19 #include <stdint.h>
20
21 #include <limits>
22 #include <memory>
23 #include <string>
24 #include <utility>
25
26 #include "absl/cleanup/cleanup.h"
27 #include "absl/random/random.h"
28 #include "absl/status/status.h"
29 #include "absl/status/statusor.h"
30 #include "absl/strings/cord.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/string_view.h"
33 #include "fcp/aggregation/core/tensor.h"
34 #include "fcp/aggregation/protocol/checkpoint_parser.h"
35 #include "fcp/aggregation/tensorflow/checkpoint_reader.h"
36 #include "fcp/base/monitoring.h"
37 #include "fcp/tensorflow/status.h"
38 #include "tensorflow/core/platform/env.h"
39 #include "tensorflow/core/platform/file_system.h"
40
41 namespace fcp::aggregation::tensorflow {
42 namespace {
43
44 using ::tensorflow::Env;
45
46 // A CheckpointParser implementation that reads TensorFlow checkpoints using a
47 // CheckpointReader.
48 class TensorflowCheckpointParser : public CheckpointParser {
49 public:
TensorflowCheckpointParser(std::string filename,std::unique_ptr<CheckpointReader> reader)50 TensorflowCheckpointParser(std::string filename,
51 std::unique_ptr<CheckpointReader> reader)
52 : filename_(std::move(filename)), reader_(std::move(reader)) {}
53
~TensorflowCheckpointParser()54 ~TensorflowCheckpointParser() override {
55 Env::Default()->DeleteFile(filename_).IgnoreError();
56 }
57
GetTensor(const std::string & name) const58 absl::StatusOr<Tensor> GetTensor(const std::string& name) const override {
59 return reader_->GetTensor(name);
60 }
61
62 private:
63 std::string filename_;
64 std::unique_ptr<CheckpointReader> reader_;
65 };
66
67 } // namespace
68
69 absl::StatusOr<std::unique_ptr<CheckpointParser>>
Create(const absl::Cord & serialized_checkpoint) const70 TensorflowCheckpointParserFactory::Create(
71 const absl::Cord& serialized_checkpoint) const {
72 // Create a (likely) unique filename in Tensorflow's RamFileSystem. This
73 // results in a second in-memory copy of the data but avoids disk I/O.
74 std::string filename =
75 absl::StrCat("ram://",
76 absl::Hex(absl::Uniform(
77 absl::BitGen(), 0, std::numeric_limits<int64_t>::max())),
78 ".ckpt");
79
80 // Write the checkpoint to the temporary file.
81 std::unique_ptr<::tensorflow::WritableFile> file;
82 FCP_RETURN_IF_ERROR(ConvertFromTensorFlowStatus(
83 Env::Default()->NewWritableFile(filename, &file)));
84 absl::Cleanup cleanup = [&] {
85 Env::Default()->DeleteFile(filename).IgnoreError();
86 };
87 for (absl::string_view chunk : serialized_checkpoint.Chunks()) {
88 FCP_RETURN_IF_ERROR(ConvertFromTensorFlowStatus(file->Append(chunk)));
89 }
90 FCP_RETURN_IF_ERROR(ConvertFromTensorFlowStatus(file->Close()));
91
92 // Return a TensorflowCheckpointParser that will read from the file.
93 FCP_ASSIGN_OR_RETURN(std::unique_ptr<CheckpointReader> reader,
94 CheckpointReader::Create(filename));
95 std::move(cleanup).Cancel();
96 return std::make_unique<TensorflowCheckpointParser>(std::move(filename),
97 std::move(reader));
98 }
99
100 } // namespace fcp::aggregation::tensorflow
101