1 /* 2 * Copyright 2022 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_AGGREGATION_TENSORFLOW_CONVERTERS_H_ 18 #define FCP_AGGREGATION_TENSORFLOW_CONVERTERS_H_ 19 20 #include <memory> 21 22 #include "fcp/aggregation/core/datatype.h" 23 #include "fcp/aggregation/core/tensor.h" 24 #include "fcp/aggregation/core/tensor_shape.h" 25 #include "fcp/aggregation/core/tensor_spec.h" 26 #include "fcp/base/monitoring.h" 27 #include "tensorflow/core/framework/tensor.h" 28 #include "tensorflow/core/framework/tensor_shape.h" 29 #include "tensorflow/core/framework/types.pb.h" 30 #include "tensorflow/core/protobuf/struct.pb.h" 31 32 namespace fcp::aggregation::tensorflow { 33 34 // Converts Tensorflow DataType to Aggregation DataType. 35 // Returns an error status if the input data type isn't supported by 36 // the Aggregation Core. 37 StatusOr<DataType> ConvertDataType(::tensorflow::DataType dtype); 38 39 // Converts Tensorflow TensorShape to Aggregation TensorShape. 40 // Note that the Tensorflow shape is expected to be valid (it seems impossible 41 // to create an invalid shape). 42 TensorShape ConvertShape(const ::tensorflow::TensorShape& shape); 43 44 // Converts Tensorflow TensorSpecProto to Aggregation TensorSpec. 45 // Returns an error status if supplied TensorSpecProto data type or shape isn't 46 // supported by the Aggregation Core. 47 StatusOr<TensorSpec> ConvertTensorSpec( 48 const ::tensorflow::TensorSpecProto& spec); 49 50 // Converts Tensorflow Tensor to Aggregation Tensor. 51 // Returns an error status if supplied Tensor data type or shape isn't 52 // supported by the Aggregation Core. 53 // Note that this function consumes the Tensorflow tensor. 54 StatusOr<Tensor> ConvertTensor(std::unique_ptr<::tensorflow::Tensor> tensor); 55 56 } // namespace fcp::aggregation::tensorflow 57 58 #endif // FCP_AGGREGATION_TENSORFLOW_CONVERTERS_H_ 59