xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/tensorflow/checkpoint_writer.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "fcp/aggregation/tensorflow/checkpoint_writer.h"

#include <string>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "fcp/aggregation/core/datatype.h"
#include "fcp/aggregation/core/tensor.h"
#include "fcp/base/monitoring.h"
#include "fcp/tensorflow/status.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/platform/tstring.h"
31 
32 namespace fcp::aggregation::tensorflow {
33 
34 namespace tf = ::tensorflow;
35 
ConvertShape(const TensorShape & shape)36 tf::TensorShape ConvertShape(const TensorShape& shape) {
37   tf::TensorShape tf_shape;
38   for (auto dim : shape.dim_sizes()) {
39     tf_shape.AddDim(dim);
40   }
41   FCP_CHECK(tf_shape.IsValid());
42   return tf_shape;
43 }
44 
45 template <typename T>
AddTensorSlice(tf::checkpoint::TensorSliceWriter * writer,const std::string & name,const tf::TensorShape & shape,const tf::TensorSlice & slice,const Tensor & tensor)46 tf::Status AddTensorSlice(tf::checkpoint::TensorSliceWriter* writer,
47                           const std::string& name, const tf::TensorShape& shape,
48                           const tf::TensorSlice& slice, const Tensor& tensor) {
49   return writer->Add<T>(name, shape, slice,
50                         static_cast<const T*>(tensor.data().data()));
51 }
52 
53 template <>
AddTensorSlice(tf::checkpoint::TensorSliceWriter * writer,const std::string & name,const tf::TensorShape & shape,const tf::TensorSlice & slice,const Tensor & tensor)54 tf::Status AddTensorSlice<string_view>(
55     tf::checkpoint::TensorSliceWriter* writer, const std::string& name,
56     const tf::TensorShape& shape, const tf::TensorSlice& slice,
57     const Tensor& tensor) {
58   std::vector<tf::tstring> values(tensor.shape().NumElements());
59   const auto* string_views =
60       static_cast<const string_view*>(tensor.data().data());
61   for (size_t i = 0; i < values.size(); ++i) {
62     values[i].assign_as_view(string_views[i].data(), string_views[i].size());
63   }
64   return writer->Add(name, shape, slice, values.data());
65 }
66 
CheckpointWriter(const std::string & filename)67 CheckpointWriter::CheckpointWriter(const std::string& filename)
68     : tensorflow_writer_(filename,
69                          tf::checkpoint::CreateTableTensorSliceBuilder) {}
70 
CheckpointWriter(const std::string & filename,tf::checkpoint::TensorSliceWriter::CreateBuilderFunction create_builder_fn)71 CheckpointWriter::CheckpointWriter(
72     const std::string& filename,
73     tf::checkpoint::TensorSliceWriter::CreateBuilderFunction create_builder_fn)
74     : tensorflow_writer_(filename, create_builder_fn) {}
75 
Add(const std::string & tensor_name,const Tensor & tensor)76 absl::Status CheckpointWriter::Add(const std::string& tensor_name,
77                                    const Tensor& tensor) {
78   tf::TensorShape tf_shape = ConvertShape(tensor.shape());
79   tf::TensorSlice tf_slice(tf_shape.dims());
80   FCP_CHECK(tensor.is_dense())
81       << "Only dense tensors with one slice are supported";
82   tf::Status tf_status;
83   DTYPE_CASES(tensor.dtype(), T,
84               tf_status = AddTensorSlice<T>(&tensorflow_writer_, tensor_name,
85                                             tf_shape, tf_slice, tensor));
86   return ConvertFromTensorFlowStatus(tf_status);
87 }
88 
Finish()89 absl::Status CheckpointWriter::Finish() {
90   return ConvertFromTensorFlowStatus(tensorflow_writer_.Finish());
91 }
92 
93 }  // namespace fcp::aggregation::tensorflow
94