xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/tensorflow/converters.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2022 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker 
17*14675a02SAndroid Build Coastguard Worker #include "fcp/aggregation/tensorflow/converters.h"
18*14675a02SAndroid Build Coastguard Worker 
19*14675a02SAndroid Build Coastguard Worker #include <memory>
20*14675a02SAndroid Build Coastguard Worker #include <utility>
21*14675a02SAndroid Build Coastguard Worker #include <vector>
22*14675a02SAndroid Build Coastguard Worker 
23*14675a02SAndroid Build Coastguard Worker #include "absl/strings/string_view.h"
24*14675a02SAndroid Build Coastguard Worker #include "fcp/aggregation/core/datatype.h"
25*14675a02SAndroid Build Coastguard Worker #include "fcp/aggregation/core/tensor.pb.h"
26*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
27*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor.h"
28*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor_shape.h"
29*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor_shape.pb.h"
30*14675a02SAndroid Build Coastguard Worker 
31*14675a02SAndroid Build Coastguard Worker namespace fcp::aggregation::tensorflow {
32*14675a02SAndroid Build Coastguard Worker 
33*14675a02SAndroid Build Coastguard Worker namespace tf = ::tensorflow;
34*14675a02SAndroid Build Coastguard Worker 
ConvertDataType(tf::DataType dtype)35*14675a02SAndroid Build Coastguard Worker StatusOr<DataType> ConvertDataType(tf::DataType dtype) {
36*14675a02SAndroid Build Coastguard Worker   switch (dtype) {
37*14675a02SAndroid Build Coastguard Worker     case tf::DT_FLOAT:
38*14675a02SAndroid Build Coastguard Worker       return DT_FLOAT;
39*14675a02SAndroid Build Coastguard Worker     case tf::DT_DOUBLE:
40*14675a02SAndroid Build Coastguard Worker       return DT_DOUBLE;
41*14675a02SAndroid Build Coastguard Worker     case tf::DT_INT32:
42*14675a02SAndroid Build Coastguard Worker       return DT_INT32;
43*14675a02SAndroid Build Coastguard Worker     case tf::DT_INT64:
44*14675a02SAndroid Build Coastguard Worker       return DT_INT64;
45*14675a02SAndroid Build Coastguard Worker     case tf::DT_STRING:
46*14675a02SAndroid Build Coastguard Worker       return DT_STRING;
47*14675a02SAndroid Build Coastguard Worker     default:
48*14675a02SAndroid Build Coastguard Worker       return FCP_STATUS(INVALID_ARGUMENT)
49*14675a02SAndroid Build Coastguard Worker              << "Unsupported tf::DataType: " << dtype;
50*14675a02SAndroid Build Coastguard Worker   }
51*14675a02SAndroid Build Coastguard Worker }
52*14675a02SAndroid Build Coastguard Worker 
ConvertShape(const tf::TensorShape & shape)53*14675a02SAndroid Build Coastguard Worker TensorShape ConvertShape(const tf::TensorShape& shape) {
54*14675a02SAndroid Build Coastguard Worker   FCP_CHECK(shape.IsFullyDefined());
55*14675a02SAndroid Build Coastguard Worker   std::vector<size_t> dim_sizes;
56*14675a02SAndroid Build Coastguard Worker   for (auto dim_size : shape.dim_sizes()) {
57*14675a02SAndroid Build Coastguard Worker     FCP_CHECK(dim_size >= 0);
58*14675a02SAndroid Build Coastguard Worker     dim_sizes.push_back(dim_size);
59*14675a02SAndroid Build Coastguard Worker   }
60*14675a02SAndroid Build Coastguard Worker   return TensorShape(dim_sizes.begin(), dim_sizes.end());
61*14675a02SAndroid Build Coastguard Worker }
62*14675a02SAndroid Build Coastguard Worker 
ConvertTensorSpec(const::tensorflow::TensorSpecProto & spec)63*14675a02SAndroid Build Coastguard Worker StatusOr<TensorSpec> ConvertTensorSpec(
64*14675a02SAndroid Build Coastguard Worker     const ::tensorflow::TensorSpecProto& spec) {
65*14675a02SAndroid Build Coastguard Worker   FCP_ASSIGN_OR_RETURN(DataType dtype, ConvertDataType(spec.dtype()));
66*14675a02SAndroid Build Coastguard Worker   tf::TensorShape tf_shape;
67*14675a02SAndroid Build Coastguard Worker   if (!tf::TensorShape::BuildTensorShape(spec.shape(), &tf_shape).ok()) {
68*14675a02SAndroid Build Coastguard Worker     return FCP_STATUS(INVALID_ARGUMENT)
69*14675a02SAndroid Build Coastguard Worker            << "Unsupported tf::TensorShape: " << spec.shape().DebugString();
70*14675a02SAndroid Build Coastguard Worker   }
71*14675a02SAndroid Build Coastguard Worker   return TensorSpec(spec.name(), dtype, ConvertShape(tf_shape));
72*14675a02SAndroid Build Coastguard Worker }
73*14675a02SAndroid Build Coastguard Worker 
74*14675a02SAndroid Build Coastguard Worker // A primitive TensorData implementation that wraps the original
75*14675a02SAndroid Build Coastguard Worker // tf::Tensor data.
76*14675a02SAndroid Build Coastguard Worker // NumericTensorDataAdapter gets the ownership of the wrapped tensor, which
77*14675a02SAndroid Build Coastguard Worker // keeps the underlying data alive.
78*14675a02SAndroid Build Coastguard Worker class NumericTensorDataAdapter : public TensorData {
79*14675a02SAndroid Build Coastguard Worker  public:
NumericTensorDataAdapter(std::unique_ptr<tf::Tensor> tensor)80*14675a02SAndroid Build Coastguard Worker   explicit NumericTensorDataAdapter(std::unique_ptr<tf::Tensor> tensor)
81*14675a02SAndroid Build Coastguard Worker       : tensor_(std::move(tensor)) {}
82*14675a02SAndroid Build Coastguard Worker 
83*14675a02SAndroid Build Coastguard Worker   // The source tf::Tensor has the data as one continuous blob.
byte_size() const84*14675a02SAndroid Build Coastguard Worker   size_t byte_size() const override { return tensor_->tensor_data().size(); }
data() const85*14675a02SAndroid Build Coastguard Worker   const void* data() const override { return tensor_->tensor_data().data(); }
86*14675a02SAndroid Build Coastguard Worker 
87*14675a02SAndroid Build Coastguard Worker  private:
88*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<tf::Tensor> tensor_;
89*14675a02SAndroid Build Coastguard Worker };
90*14675a02SAndroid Build Coastguard Worker 
91*14675a02SAndroid Build Coastguard Worker // Similar to  NumericTensorDataAdapter but performs additional conversion
92*14675a02SAndroid Build Coastguard Worker // of the original tensor tstring values to string_view while keeping the
93*14675a02SAndroid Build Coastguard Worker // the tstring values owned by the original tensor.
94*14675a02SAndroid Build Coastguard Worker class StringTensorDataAdapter : public TensorData {
95*14675a02SAndroid Build Coastguard Worker  public:
StringTensorDataAdapter(std::unique_ptr<tf::Tensor> tensor)96*14675a02SAndroid Build Coastguard Worker   explicit StringTensorDataAdapter(std::unique_ptr<tf::Tensor> tensor)
97*14675a02SAndroid Build Coastguard Worker       : tensor_(std::move(tensor)), string_views_(tensor_->NumElements()) {
98*14675a02SAndroid Build Coastguard Worker     auto string_values = tensor_->flat<tf::tstring>();
99*14675a02SAndroid Build Coastguard Worker     for (size_t i = 0; i < string_values.size(); ++i) {
100*14675a02SAndroid Build Coastguard Worker       string_views_[i] = string_values(i);
101*14675a02SAndroid Build Coastguard Worker     }
102*14675a02SAndroid Build Coastguard Worker   }
103*14675a02SAndroid Build Coastguard Worker 
byte_size() const104*14675a02SAndroid Build Coastguard Worker   size_t byte_size() const override {
105*14675a02SAndroid Build Coastguard Worker     return string_views_.size() * sizeof(string_view);
106*14675a02SAndroid Build Coastguard Worker   }
data() const107*14675a02SAndroid Build Coastguard Worker   const void* data() const override { return string_views_.data(); }
108*14675a02SAndroid Build Coastguard Worker 
109*14675a02SAndroid Build Coastguard Worker  private:
110*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<tf::Tensor> tensor_;
111*14675a02SAndroid Build Coastguard Worker   std::vector<string_view> string_views_;
112*14675a02SAndroid Build Coastguard Worker };
113*14675a02SAndroid Build Coastguard Worker 
114*14675a02SAndroid Build Coastguard Worker // Conversion of tensor data for numeric data types, which can be
115*14675a02SAndroid Build Coastguard Worker // done by simply wrapping the original tensorflow tensor data.
116*14675a02SAndroid Build Coastguard Worker template <typename t>
ConvertTensorData(std::unique_ptr<tf::Tensor> tensor)117*14675a02SAndroid Build Coastguard Worker std::unique_ptr<TensorData> ConvertTensorData(
118*14675a02SAndroid Build Coastguard Worker     std::unique_ptr<tf::Tensor> tensor) {
119*14675a02SAndroid Build Coastguard Worker   return std::make_unique<NumericTensorDataAdapter>(std::move(tensor));
120*14675a02SAndroid Build Coastguard Worker }
121*14675a02SAndroid Build Coastguard Worker 
122*14675a02SAndroid Build Coastguard Worker // Specialization of ConvertTensorData for the DT_STRING data type.
123*14675a02SAndroid Build Coastguard Worker template <>
ConvertTensorData(std::unique_ptr<tf::Tensor> tensor)124*14675a02SAndroid Build Coastguard Worker std::unique_ptr<TensorData> ConvertTensorData<string_view>(
125*14675a02SAndroid Build Coastguard Worker     std::unique_ptr<tf::Tensor> tensor) {
126*14675a02SAndroid Build Coastguard Worker   return std::make_unique<StringTensorDataAdapter>(std::move(tensor));
127*14675a02SAndroid Build Coastguard Worker }
128*14675a02SAndroid Build Coastguard Worker 
ConvertTensor(std::unique_ptr<tf::Tensor> tensor)129*14675a02SAndroid Build Coastguard Worker StatusOr<Tensor> ConvertTensor(std::unique_ptr<tf::Tensor> tensor) {
130*14675a02SAndroid Build Coastguard Worker   FCP_ASSIGN_OR_RETURN(DataType dtype, ConvertDataType(tensor->dtype()));
131*14675a02SAndroid Build Coastguard Worker   TensorShape shape = ConvertShape(tensor->shape());
132*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<TensorData> data;
133*14675a02SAndroid Build Coastguard Worker   DTYPE_CASES(dtype, T, data = ConvertTensorData<T>(std::move(tensor)));
134*14675a02SAndroid Build Coastguard Worker   return Tensor::Create(dtype, std::move(shape), std::move(data));
135*14675a02SAndroid Build Coastguard Worker }
136*14675a02SAndroid Build Coastguard Worker 
137*14675a02SAndroid Build Coastguard Worker }  // namespace fcp::aggregation::tensorflow
138