1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/aggregation/tensorflow/tensorflow_checkpoint_parser_factory.h"
18 
19 #include <stdint.h>
20 
21 #include <limits>
22 #include <memory>
23 #include <string>
24 #include <utility>
25 
26 #include "absl/cleanup/cleanup.h"
27 #include "absl/random/random.h"
28 #include "absl/status/status.h"
29 #include "absl/status/statusor.h"
30 #include "absl/strings/cord.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/string_view.h"
33 #include "fcp/aggregation/core/tensor.h"
34 #include "fcp/aggregation/protocol/checkpoint_parser.h"
35 #include "fcp/aggregation/tensorflow/checkpoint_reader.h"
36 #include "fcp/base/monitoring.h"
37 #include "fcp/tensorflow/status.h"
38 #include "tensorflow/core/platform/env.h"
39 #include "tensorflow/core/platform/file_system.h"
40 
41 namespace fcp::aggregation::tensorflow {
42 namespace {
43 
44 using ::tensorflow::Env;
45 
46 // A CheckpointParser implementation that reads TensorFlow checkpoints using a
47 // CheckpointReader.
48 class TensorflowCheckpointParser : public CheckpointParser {
49  public:
TensorflowCheckpointParser(std::string filename,std::unique_ptr<CheckpointReader> reader)50   TensorflowCheckpointParser(std::string filename,
51                              std::unique_ptr<CheckpointReader> reader)
52       : filename_(std::move(filename)), reader_(std::move(reader)) {}
53 
~TensorflowCheckpointParser()54   ~TensorflowCheckpointParser() override {
55     Env::Default()->DeleteFile(filename_).IgnoreError();
56   }
57 
GetTensor(const std::string & name) const58   absl::StatusOr<Tensor> GetTensor(const std::string& name) const override {
59     return reader_->GetTensor(name);
60   }
61 
62  private:
63   std::string filename_;
64   std::unique_ptr<CheckpointReader> reader_;
65 };
66 
67 }  // namespace
68 
69 absl::StatusOr<std::unique_ptr<CheckpointParser>>
Create(const absl::Cord & serialized_checkpoint) const70 TensorflowCheckpointParserFactory::Create(
71     const absl::Cord& serialized_checkpoint) const {
72   // Create a (likely) unique filename in Tensorflow's RamFileSystem. This
73   // results in a second in-memory copy of the data but avoids disk I/O.
74   std::string filename =
75       absl::StrCat("ram://",
76                    absl::Hex(absl::Uniform(
77                        absl::BitGen(), 0, std::numeric_limits<int64_t>::max())),
78                    ".ckpt");
79 
80   // Write the checkpoint to the temporary file.
81   std::unique_ptr<::tensorflow::WritableFile> file;
82   FCP_RETURN_IF_ERROR(ConvertFromTensorFlowStatus(
83       Env::Default()->NewWritableFile(filename, &file)));
84   absl::Cleanup cleanup = [&] {
85     Env::Default()->DeleteFile(filename).IgnoreError();
86   };
87   for (absl::string_view chunk : serialized_checkpoint.Chunks()) {
88     FCP_RETURN_IF_ERROR(ConvertFromTensorFlowStatus(file->Append(chunk)));
89   }
90   FCP_RETURN_IF_ERROR(ConvertFromTensorFlowStatus(file->Close()));
91 
92   // Return a TensorflowCheckpointParser that will read from the file.
93   FCP_ASSIGN_OR_RETURN(std::unique_ptr<CheckpointReader> reader,
94                        CheckpointReader::Create(filename));
95   std::move(cleanup).Cancel();
96   return std::make_unique<TensorflowCheckpointParser>(std::move(filename),
97                                                       std::move(reader));
98 }
99 
100 }  // namespace fcp::aggregation::tensorflow
101