1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.h"
18 
19 #include <cstdint>
20 #include <limits>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 
25 #include "absl/random/random.h"
26 #include "absl/status/status.h"
27 #include "absl/status/statusor.h"
28 #include "absl/strings/cord.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/string_view.h"
31 #include "fcp/aggregation/core/tensor.h"
32 #include "fcp/aggregation/protocol/checkpoint_builder.h"
33 #include "fcp/aggregation/tensorflow/checkpoint_writer.h"
34 #include "fcp/base/monitoring.h"
35 #include "fcp/tensorflow/status.h"
36 #include "tensorflow/core/platform/env.h"
37 #include "tensorflow/core/platform/file_system.h"
38 #include "tensorflow/core/platform/status.h"
39 
40 namespace fcp::aggregation::tensorflow {
41 namespace {
42 
43 using ::tensorflow::Env;
44 
45 // A CheckpointBuilder implementation that builds TensorFlow checkpoints using a
46 // CheckpointWriter.
47 class TensorflowCheckpointBuilder : public CheckpointBuilder {
48  public:
TensorflowCheckpointBuilder(std::string filename)49   explicit TensorflowCheckpointBuilder(std::string filename)
50       : filename_(std::move(filename)) {}
51 
~TensorflowCheckpointBuilder()52   ~TensorflowCheckpointBuilder() override {
53     Env::Default()->DeleteFile(filename_).IgnoreError();
54   }
55 
Add(const std::string & name,const Tensor & tensor)56   absl::Status Add(const std::string& name, const Tensor& tensor) override {
57     return writer_.Add(name, tensor);
58   }
59 
Build()60   absl::StatusOr<absl::Cord> Build() override {
61     FCP_RETURN_IF_ERROR(writer_.Finish());
62 
63     // Read the checkpoints contents from the file.
64     std::unique_ptr<::tensorflow::RandomAccessFile> file;
65     FCP_RETURN_IF_ERROR(ConvertFromTensorFlowStatus(
66         Env::Default()->NewRandomAccessFile(filename_, &file)));
67 
68     absl::Cord output;
69     for (;;) {
70       char scratch[4096];
71       absl::string_view read_result;
72       ::tensorflow::Status status =
73           file->Read(output.size(), sizeof(scratch), &read_result, scratch);
74       output.Append(read_result);
75       if (status.code() == ::tensorflow::error::OUT_OF_RANGE) {
76         return output;
77       } else if (!status.ok()) {
78         return ConvertFromTensorFlowStatus(status);
79       }
80     }
81   }
82 
83  private:
84   std::string filename_;
85   CheckpointWriter writer_{filename_};
86 };
87 
88 }  // namespace
89 
Create() const90 std::unique_ptr<CheckpointBuilder> TensorflowCheckpointBuilderFactory::Create()
91     const {
92   // Create a (likely) unique filename in Tensorflow's RamFileSystem. This
93   // results in a second in-memory copy of the data but avoids disk I/O.
94   std::string filename =
95       absl::StrCat("ram://",
96                    absl::Hex(absl::Uniform(
97                        absl::BitGen(), 0, std::numeric_limits<int64_t>::max())),
98                    ".ckpt");
99 
100   return std::make_unique<TensorflowCheckpointBuilder>(std::move(filename));
101 }
102 
103 }  // namespace fcp::aggregation::tensorflow
104