1 /* 2 * Copyright 2022 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_AGGREGATION_TENSORFLOW_CHECKPOINT_WRITER_H_ 18 #define FCP_AGGREGATION_TENSORFLOW_CHECKPOINT_WRITER_H_ 19 20 #include <string> 21 22 #include "absl/status/status.h" 23 #include "fcp/aggregation/core/tensor.h" 24 #include "tensorflow/core/util/tensor_slice_writer.h" 25 26 namespace fcp::aggregation::tensorflow { 27 28 // This class wraps TensorSliceWriter and provides a similar 29 // functionality but accepts Aggregation Core tensors instead. 30 // This class is designed to write only dense tensors that consist of a 31 // single slice. 32 class CheckpointWriter final { 33 public: 34 // CheckpointReader is neither copyable nor moveable 35 CheckpointWriter(const CheckpointWriter&) = delete; 36 CheckpointWriter& operator=(const CheckpointWriter&) = delete; 37 38 // Constructs CheckpointWriter for the given filename. 39 explicit CheckpointWriter(const std::string& filename); 40 41 // Constructs CheckpointWriter for the given filename and 42 // CreateBuilderFunction. 43 explicit CheckpointWriter( 44 const std::string& filename, 45 ::tensorflow::checkpoint::TensorSliceWriter::CreateBuilderFunction 46 create_builder_fn); 47 48 // Adds a tensor to the checkpoint. 49 absl::Status Add(const std::string& tensor_name, const Tensor& tensor); 50 51 // Writes the checkpoint to the file. 52 absl::Status Finish(); 53 54 private: 55 ::tensorflow::checkpoint::TensorSliceWriter tensorflow_writer_; 56 }; 57 58 } // namespace fcp::aggregation::tensorflow 59 60 #endif // FCP_AGGREGATION_TENSORFLOW_CHECKPOINT_WRITER_H_ 61