1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "fcp/aggregation/tensorflow/checkpoint_writer.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "fcp/aggregation/core/datatype.h"
#include "fcp/aggregation/core/tensor.h"
#include "fcp/base/monitoring.h"
#include "fcp/tensorflow/status.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/platform/tstring.h"
31
32 namespace fcp::aggregation::tensorflow {
33
34 namespace tf = ::tensorflow;
35
ConvertShape(const TensorShape & shape)36 tf::TensorShape ConvertShape(const TensorShape& shape) {
37 tf::TensorShape tf_shape;
38 for (auto dim : shape.dim_sizes()) {
39 tf_shape.AddDim(dim);
40 }
41 FCP_CHECK(tf_shape.IsValid());
42 return tf_shape;
43 }
44
45 template <typename T>
AddTensorSlice(tf::checkpoint::TensorSliceWriter * writer,const std::string & name,const tf::TensorShape & shape,const tf::TensorSlice & slice,const Tensor & tensor)46 tf::Status AddTensorSlice(tf::checkpoint::TensorSliceWriter* writer,
47 const std::string& name, const tf::TensorShape& shape,
48 const tf::TensorSlice& slice, const Tensor& tensor) {
49 return writer->Add<T>(name, shape, slice,
50 static_cast<const T*>(tensor.data().data()));
51 }
52
53 template <>
AddTensorSlice(tf::checkpoint::TensorSliceWriter * writer,const std::string & name,const tf::TensorShape & shape,const tf::TensorSlice & slice,const Tensor & tensor)54 tf::Status AddTensorSlice<string_view>(
55 tf::checkpoint::TensorSliceWriter* writer, const std::string& name,
56 const tf::TensorShape& shape, const tf::TensorSlice& slice,
57 const Tensor& tensor) {
58 std::vector<tf::tstring> values(tensor.shape().NumElements());
59 const auto* string_views =
60 static_cast<const string_view*>(tensor.data().data());
61 for (size_t i = 0; i < values.size(); ++i) {
62 values[i].assign_as_view(string_views[i].data(), string_views[i].size());
63 }
64 return writer->Add(name, shape, slice, values.data());
65 }
66
CheckpointWriter(const std::string & filename)67 CheckpointWriter::CheckpointWriter(const std::string& filename)
68 : tensorflow_writer_(filename,
69 tf::checkpoint::CreateTableTensorSliceBuilder) {}
70
CheckpointWriter(const std::string & filename,tf::checkpoint::TensorSliceWriter::CreateBuilderFunction create_builder_fn)71 CheckpointWriter::CheckpointWriter(
72 const std::string& filename,
73 tf::checkpoint::TensorSliceWriter::CreateBuilderFunction create_builder_fn)
74 : tensorflow_writer_(filename, create_builder_fn) {}
75
Add(const std::string & tensor_name,const Tensor & tensor)76 absl::Status CheckpointWriter::Add(const std::string& tensor_name,
77 const Tensor& tensor) {
78 tf::TensorShape tf_shape = ConvertShape(tensor.shape());
79 tf::TensorSlice tf_slice(tf_shape.dims());
80 FCP_CHECK(tensor.is_dense())
81 << "Only dense tensors with one slice are supported";
82 tf::Status tf_status;
83 DTYPE_CASES(tensor.dtype(), T,
84 tf_status = AddTensorSlice<T>(&tensorflow_writer_, tensor_name,
85 tf_shape, tf_slice, tensor));
86 return ConvertFromTensorFlowStatus(tf_status);
87 }
88
Finish()89 absl::Status CheckpointWriter::Finish() {
90 return ConvertFromTensorFlowStatus(tensorflow_writer_.Finish());
91 }
92
93 } // namespace fcp::aggregation::tensorflow
94