/* * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef FCP_AGGREGATION_TENSORFLOW_CHECKPOINT_READER_H_ #define FCP_AGGREGATION_TENSORFLOW_CHECKPOINT_READER_H_ #include #include #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" #include "fcp/aggregation/core/datatype.h" #include "fcp/aggregation/core/tensor.h" #include "fcp/aggregation/core/tensor_shape.h" #include "tensorflow/c/checkpoint_reader.h" namespace fcp::aggregation::tensorflow { // This class wraps Tensorflow checkpoint reader and provides a similar // functionality but returns Aggregation Core tensors instead. // This class is designed to read only dense tensors that consist of a // single slice. class CheckpointReader final { public: // CheckpointReader is neither copyable nor moveable CheckpointReader(const CheckpointReader&) = delete; CheckpointReader& operator=(const CheckpointReader&) = delete; using DataTypeMap = absl::flat_hash_map; using TensorShapeMap = absl::flat_hash_map; static absl::StatusOr> Create( const std::string& filename); const DataTypeMap& GetDataTypeMap() const { return data_type_map_; } const TensorShapeMap& GetTensorShapeMap() const { return shape_map_; } absl::StatusOr GetTensor(const std::string& name) const; private: CheckpointReader(std::unique_ptr<::tensorflow::checkpoint::CheckpointReader> tensorflow_checkpoint_reader, DataTypeMap data_type_map, TensorShapeMap shape_map); std::unique_ptr<::tensorflow::checkpoint::CheckpointReader> tf_checkpoint_reader_; DataTypeMap data_type_map_; TensorShapeMap shape_map_; }; } // namespace fcp::aggregation::tensorflow #endif // FCP_AGGREGATION_TENSORFLOW_CHECKPOINT_READER_H_