1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/aggregation/tensorflow/tensorflow_checkpoint_builder_factory.h"
18
19 #include <cstdint>
20 #include <limits>
21 #include <memory>
22 #include <string>
23 #include <utility>
24
25 #include "absl/random/random.h"
26 #include "absl/status/status.h"
27 #include "absl/status/statusor.h"
28 #include "absl/strings/cord.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/string_view.h"
31 #include "fcp/aggregation/core/tensor.h"
32 #include "fcp/aggregation/protocol/checkpoint_builder.h"
33 #include "fcp/aggregation/tensorflow/checkpoint_writer.h"
34 #include "fcp/base/monitoring.h"
35 #include "fcp/tensorflow/status.h"
36 #include "tensorflow/core/platform/env.h"
37 #include "tensorflow/core/platform/file_system.h"
38 #include "tensorflow/core/platform/status.h"
39
40 namespace fcp::aggregation::tensorflow {
41 namespace {
42
43 using ::tensorflow::Env;
44
45 // A CheckpointBuilder implementation that builds TensorFlow checkpoints using a
46 // CheckpointWriter.
47 class TensorflowCheckpointBuilder : public CheckpointBuilder {
48 public:
TensorflowCheckpointBuilder(std::string filename)49 explicit TensorflowCheckpointBuilder(std::string filename)
50 : filename_(std::move(filename)) {}
51
~TensorflowCheckpointBuilder()52 ~TensorflowCheckpointBuilder() override {
53 Env::Default()->DeleteFile(filename_).IgnoreError();
54 }
55
Add(const std::string & name,const Tensor & tensor)56 absl::Status Add(const std::string& name, const Tensor& tensor) override {
57 return writer_.Add(name, tensor);
58 }
59
Build()60 absl::StatusOr<absl::Cord> Build() override {
61 FCP_RETURN_IF_ERROR(writer_.Finish());
62
63 // Read the checkpoints contents from the file.
64 std::unique_ptr<::tensorflow::RandomAccessFile> file;
65 FCP_RETURN_IF_ERROR(ConvertFromTensorFlowStatus(
66 Env::Default()->NewRandomAccessFile(filename_, &file)));
67
68 absl::Cord output;
69 for (;;) {
70 char scratch[4096];
71 absl::string_view read_result;
72 ::tensorflow::Status status =
73 file->Read(output.size(), sizeof(scratch), &read_result, scratch);
74 output.Append(read_result);
75 if (status.code() == ::tensorflow::error::OUT_OF_RANGE) {
76 return output;
77 } else if (!status.ok()) {
78 return ConvertFromTensorFlowStatus(status);
79 }
80 }
81 }
82
83 private:
84 std::string filename_;
85 CheckpointWriter writer_{filename_};
86 };
87
88 } // namespace
89
Create() const90 std::unique_ptr<CheckpointBuilder> TensorflowCheckpointBuilderFactory::Create()
91 const {
92 // Create a (likely) unique filename in Tensorflow's RamFileSystem. This
93 // results in a second in-memory copy of the data but avoids disk I/O.
94 std::string filename =
95 absl::StrCat("ram://",
96 absl::Hex(absl::Uniform(
97 absl::BitGen(), 0, std::numeric_limits<int64_t>::max())),
98 ".ckpt");
99
100 return std::make_unique<TensorflowCheckpointBuilder>(std::move(filename));
101 }
102
103 } // namespace fcp::aggregation::tensorflow
104