xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/tensorflow/converters.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/aggregation/tensorflow/converters.h"
18 
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/strings/string_view.h"
24 #include "fcp/aggregation/core/datatype.h"
25 #include "fcp/aggregation/core/tensor.pb.h"
26 #include "fcp/base/monitoring.h"
27 #include "tensorflow/core/framework/tensor.h"
28 #include "tensorflow/core/framework/tensor_shape.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 
31 namespace fcp::aggregation::tensorflow {
32 
33 namespace tf = ::tensorflow;
34 
ConvertDataType(tf::DataType dtype)35 StatusOr<DataType> ConvertDataType(tf::DataType dtype) {
36   switch (dtype) {
37     case tf::DT_FLOAT:
38       return DT_FLOAT;
39     case tf::DT_DOUBLE:
40       return DT_DOUBLE;
41     case tf::DT_INT32:
42       return DT_INT32;
43     case tf::DT_INT64:
44       return DT_INT64;
45     case tf::DT_STRING:
46       return DT_STRING;
47     default:
48       return FCP_STATUS(INVALID_ARGUMENT)
49              << "Unsupported tf::DataType: " << dtype;
50   }
51 }
52 
ConvertShape(const tf::TensorShape & shape)53 TensorShape ConvertShape(const tf::TensorShape& shape) {
54   FCP_CHECK(shape.IsFullyDefined());
55   std::vector<size_t> dim_sizes;
56   for (auto dim_size : shape.dim_sizes()) {
57     FCP_CHECK(dim_size >= 0);
58     dim_sizes.push_back(dim_size);
59   }
60   return TensorShape(dim_sizes.begin(), dim_sizes.end());
61 }
62 
ConvertTensorSpec(const::tensorflow::TensorSpecProto & spec)63 StatusOr<TensorSpec> ConvertTensorSpec(
64     const ::tensorflow::TensorSpecProto& spec) {
65   FCP_ASSIGN_OR_RETURN(DataType dtype, ConvertDataType(spec.dtype()));
66   tf::TensorShape tf_shape;
67   if (!tf::TensorShape::BuildTensorShape(spec.shape(), &tf_shape).ok()) {
68     return FCP_STATUS(INVALID_ARGUMENT)
69            << "Unsupported tf::TensorShape: " << spec.shape().DebugString();
70   }
71   return TensorSpec(spec.name(), dtype, ConvertShape(tf_shape));
72 }
73 
74 // A primitive TensorData implementation that wraps the original
75 // tf::Tensor data.
76 // NumericTensorDataAdapter gets the ownership of the wrapped tensor, which
77 // keeps the underlying data alive.
78 class NumericTensorDataAdapter : public TensorData {
79  public:
NumericTensorDataAdapter(std::unique_ptr<tf::Tensor> tensor)80   explicit NumericTensorDataAdapter(std::unique_ptr<tf::Tensor> tensor)
81       : tensor_(std::move(tensor)) {}
82 
83   // The source tf::Tensor has the data as one continuous blob.
byte_size() const84   size_t byte_size() const override { return tensor_->tensor_data().size(); }
data() const85   const void* data() const override { return tensor_->tensor_data().data(); }
86 
87  private:
88   std::unique_ptr<tf::Tensor> tensor_;
89 };
90 
91 // Similar to  NumericTensorDataAdapter but performs additional conversion
92 // of the original tensor tstring values to string_view while keeping the
93 // the tstring values owned by the original tensor.
94 class StringTensorDataAdapter : public TensorData {
95  public:
StringTensorDataAdapter(std::unique_ptr<tf::Tensor> tensor)96   explicit StringTensorDataAdapter(std::unique_ptr<tf::Tensor> tensor)
97       : tensor_(std::move(tensor)), string_views_(tensor_->NumElements()) {
98     auto string_values = tensor_->flat<tf::tstring>();
99     for (size_t i = 0; i < string_values.size(); ++i) {
100       string_views_[i] = string_values(i);
101     }
102   }
103 
byte_size() const104   size_t byte_size() const override {
105     return string_views_.size() * sizeof(string_view);
106   }
data() const107   const void* data() const override { return string_views_.data(); }
108 
109  private:
110   std::unique_ptr<tf::Tensor> tensor_;
111   std::vector<string_view> string_views_;
112 };
113 
114 // Conversion of tensor data for numeric data types, which can be
115 // done by simply wrapping the original tensorflow tensor data.
116 template <typename t>
ConvertTensorData(std::unique_ptr<tf::Tensor> tensor)117 std::unique_ptr<TensorData> ConvertTensorData(
118     std::unique_ptr<tf::Tensor> tensor) {
119   return std::make_unique<NumericTensorDataAdapter>(std::move(tensor));
120 }
121 
122 // Specialization of ConvertTensorData for the DT_STRING data type.
123 template <>
ConvertTensorData(std::unique_ptr<tf::Tensor> tensor)124 std::unique_ptr<TensorData> ConvertTensorData<string_view>(
125     std::unique_ptr<tf::Tensor> tensor) {
126   return std::make_unique<StringTensorDataAdapter>(std::move(tensor));
127 }
128 
ConvertTensor(std::unique_ptr<tf::Tensor> tensor)129 StatusOr<Tensor> ConvertTensor(std::unique_ptr<tf::Tensor> tensor) {
130   FCP_ASSIGN_OR_RETURN(DataType dtype, ConvertDataType(tensor->dtype()));
131   TensorShape shape = ConvertShape(tensor->shape());
132   std::unique_ptr<TensorData> data;
133   DTYPE_CASES(dtype, T, data = ConvertTensorData<T>(std::move(tensor)));
134   return Tensor::Create(dtype, std::move(shape), std::move(data));
135 }
136 
137 }  // namespace fcp::aggregation::tensorflow
138