1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker * Copyright 2022 Google LLC
3*14675a02SAndroid Build Coastguard Worker *
4*14675a02SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker *
8*14675a02SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker *
10*14675a02SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker */
16*14675a02SAndroid Build Coastguard Worker
17*14675a02SAndroid Build Coastguard Worker #include "fcp/aggregation/tensorflow/converters.h"
18*14675a02SAndroid Build Coastguard Worker
19*14675a02SAndroid Build Coastguard Worker #include <memory>
20*14675a02SAndroid Build Coastguard Worker #include <utility>
21*14675a02SAndroid Build Coastguard Worker #include <vector>
22*14675a02SAndroid Build Coastguard Worker
23*14675a02SAndroid Build Coastguard Worker #include "absl/strings/string_view.h"
24*14675a02SAndroid Build Coastguard Worker #include "fcp/aggregation/core/datatype.h"
25*14675a02SAndroid Build Coastguard Worker #include "fcp/aggregation/core/tensor.pb.h"
26*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
27*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor.h"
28*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor_shape.h"
29*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/framework/tensor_shape.pb.h"
30*14675a02SAndroid Build Coastguard Worker
31*14675a02SAndroid Build Coastguard Worker namespace fcp::aggregation::tensorflow {
32*14675a02SAndroid Build Coastguard Worker
33*14675a02SAndroid Build Coastguard Worker namespace tf = ::tensorflow;
34*14675a02SAndroid Build Coastguard Worker
ConvertDataType(tf::DataType dtype)35*14675a02SAndroid Build Coastguard Worker StatusOr<DataType> ConvertDataType(tf::DataType dtype) {
36*14675a02SAndroid Build Coastguard Worker switch (dtype) {
37*14675a02SAndroid Build Coastguard Worker case tf::DT_FLOAT:
38*14675a02SAndroid Build Coastguard Worker return DT_FLOAT;
39*14675a02SAndroid Build Coastguard Worker case tf::DT_DOUBLE:
40*14675a02SAndroid Build Coastguard Worker return DT_DOUBLE;
41*14675a02SAndroid Build Coastguard Worker case tf::DT_INT32:
42*14675a02SAndroid Build Coastguard Worker return DT_INT32;
43*14675a02SAndroid Build Coastguard Worker case tf::DT_INT64:
44*14675a02SAndroid Build Coastguard Worker return DT_INT64;
45*14675a02SAndroid Build Coastguard Worker case tf::DT_STRING:
46*14675a02SAndroid Build Coastguard Worker return DT_STRING;
47*14675a02SAndroid Build Coastguard Worker default:
48*14675a02SAndroid Build Coastguard Worker return FCP_STATUS(INVALID_ARGUMENT)
49*14675a02SAndroid Build Coastguard Worker << "Unsupported tf::DataType: " << dtype;
50*14675a02SAndroid Build Coastguard Worker }
51*14675a02SAndroid Build Coastguard Worker }
52*14675a02SAndroid Build Coastguard Worker
ConvertShape(const tf::TensorShape & shape)53*14675a02SAndroid Build Coastguard Worker TensorShape ConvertShape(const tf::TensorShape& shape) {
54*14675a02SAndroid Build Coastguard Worker FCP_CHECK(shape.IsFullyDefined());
55*14675a02SAndroid Build Coastguard Worker std::vector<size_t> dim_sizes;
56*14675a02SAndroid Build Coastguard Worker for (auto dim_size : shape.dim_sizes()) {
57*14675a02SAndroid Build Coastguard Worker FCP_CHECK(dim_size >= 0);
58*14675a02SAndroid Build Coastguard Worker dim_sizes.push_back(dim_size);
59*14675a02SAndroid Build Coastguard Worker }
60*14675a02SAndroid Build Coastguard Worker return TensorShape(dim_sizes.begin(), dim_sizes.end());
61*14675a02SAndroid Build Coastguard Worker }
62*14675a02SAndroid Build Coastguard Worker
ConvertTensorSpec(const::tensorflow::TensorSpecProto & spec)63*14675a02SAndroid Build Coastguard Worker StatusOr<TensorSpec> ConvertTensorSpec(
64*14675a02SAndroid Build Coastguard Worker const ::tensorflow::TensorSpecProto& spec) {
65*14675a02SAndroid Build Coastguard Worker FCP_ASSIGN_OR_RETURN(DataType dtype, ConvertDataType(spec.dtype()));
66*14675a02SAndroid Build Coastguard Worker tf::TensorShape tf_shape;
67*14675a02SAndroid Build Coastguard Worker if (!tf::TensorShape::BuildTensorShape(spec.shape(), &tf_shape).ok()) {
68*14675a02SAndroid Build Coastguard Worker return FCP_STATUS(INVALID_ARGUMENT)
69*14675a02SAndroid Build Coastguard Worker << "Unsupported tf::TensorShape: " << spec.shape().DebugString();
70*14675a02SAndroid Build Coastguard Worker }
71*14675a02SAndroid Build Coastguard Worker return TensorSpec(spec.name(), dtype, ConvertShape(tf_shape));
72*14675a02SAndroid Build Coastguard Worker }
73*14675a02SAndroid Build Coastguard Worker
74*14675a02SAndroid Build Coastguard Worker // A primitive TensorData implementation that wraps the original
75*14675a02SAndroid Build Coastguard Worker // tf::Tensor data.
76*14675a02SAndroid Build Coastguard Worker // NumericTensorDataAdapter gets the ownership of the wrapped tensor, which
77*14675a02SAndroid Build Coastguard Worker // keeps the underlying data alive.
78*14675a02SAndroid Build Coastguard Worker class NumericTensorDataAdapter : public TensorData {
79*14675a02SAndroid Build Coastguard Worker public:
NumericTensorDataAdapter(std::unique_ptr<tf::Tensor> tensor)80*14675a02SAndroid Build Coastguard Worker explicit NumericTensorDataAdapter(std::unique_ptr<tf::Tensor> tensor)
81*14675a02SAndroid Build Coastguard Worker : tensor_(std::move(tensor)) {}
82*14675a02SAndroid Build Coastguard Worker
83*14675a02SAndroid Build Coastguard Worker // The source tf::Tensor has the data as one continuous blob.
byte_size() const84*14675a02SAndroid Build Coastguard Worker size_t byte_size() const override { return tensor_->tensor_data().size(); }
data() const85*14675a02SAndroid Build Coastguard Worker const void* data() const override { return tensor_->tensor_data().data(); }
86*14675a02SAndroid Build Coastguard Worker
87*14675a02SAndroid Build Coastguard Worker private:
88*14675a02SAndroid Build Coastguard Worker std::unique_ptr<tf::Tensor> tensor_;
89*14675a02SAndroid Build Coastguard Worker };
90*14675a02SAndroid Build Coastguard Worker
91*14675a02SAndroid Build Coastguard Worker // Similar to NumericTensorDataAdapter but performs additional conversion
92*14675a02SAndroid Build Coastguard Worker // of the original tensor tstring values to string_view while keeping the
93*14675a02SAndroid Build Coastguard Worker // the tstring values owned by the original tensor.
94*14675a02SAndroid Build Coastguard Worker class StringTensorDataAdapter : public TensorData {
95*14675a02SAndroid Build Coastguard Worker public:
StringTensorDataAdapter(std::unique_ptr<tf::Tensor> tensor)96*14675a02SAndroid Build Coastguard Worker explicit StringTensorDataAdapter(std::unique_ptr<tf::Tensor> tensor)
97*14675a02SAndroid Build Coastguard Worker : tensor_(std::move(tensor)), string_views_(tensor_->NumElements()) {
98*14675a02SAndroid Build Coastguard Worker auto string_values = tensor_->flat<tf::tstring>();
99*14675a02SAndroid Build Coastguard Worker for (size_t i = 0; i < string_values.size(); ++i) {
100*14675a02SAndroid Build Coastguard Worker string_views_[i] = string_values(i);
101*14675a02SAndroid Build Coastguard Worker }
102*14675a02SAndroid Build Coastguard Worker }
103*14675a02SAndroid Build Coastguard Worker
byte_size() const104*14675a02SAndroid Build Coastguard Worker size_t byte_size() const override {
105*14675a02SAndroid Build Coastguard Worker return string_views_.size() * sizeof(string_view);
106*14675a02SAndroid Build Coastguard Worker }
data() const107*14675a02SAndroid Build Coastguard Worker const void* data() const override { return string_views_.data(); }
108*14675a02SAndroid Build Coastguard Worker
109*14675a02SAndroid Build Coastguard Worker private:
110*14675a02SAndroid Build Coastguard Worker std::unique_ptr<tf::Tensor> tensor_;
111*14675a02SAndroid Build Coastguard Worker std::vector<string_view> string_views_;
112*14675a02SAndroid Build Coastguard Worker };
113*14675a02SAndroid Build Coastguard Worker
114*14675a02SAndroid Build Coastguard Worker // Conversion of tensor data for numeric data types, which can be
115*14675a02SAndroid Build Coastguard Worker // done by simply wrapping the original tensorflow tensor data.
116*14675a02SAndroid Build Coastguard Worker template <typename t>
ConvertTensorData(std::unique_ptr<tf::Tensor> tensor)117*14675a02SAndroid Build Coastguard Worker std::unique_ptr<TensorData> ConvertTensorData(
118*14675a02SAndroid Build Coastguard Worker std::unique_ptr<tf::Tensor> tensor) {
119*14675a02SAndroid Build Coastguard Worker return std::make_unique<NumericTensorDataAdapter>(std::move(tensor));
120*14675a02SAndroid Build Coastguard Worker }
121*14675a02SAndroid Build Coastguard Worker
122*14675a02SAndroid Build Coastguard Worker // Specialization of ConvertTensorData for the DT_STRING data type.
123*14675a02SAndroid Build Coastguard Worker template <>
ConvertTensorData(std::unique_ptr<tf::Tensor> tensor)124*14675a02SAndroid Build Coastguard Worker std::unique_ptr<TensorData> ConvertTensorData<string_view>(
125*14675a02SAndroid Build Coastguard Worker std::unique_ptr<tf::Tensor> tensor) {
126*14675a02SAndroid Build Coastguard Worker return std::make_unique<StringTensorDataAdapter>(std::move(tensor));
127*14675a02SAndroid Build Coastguard Worker }
128*14675a02SAndroid Build Coastguard Worker
ConvertTensor(std::unique_ptr<tf::Tensor> tensor)129*14675a02SAndroid Build Coastguard Worker StatusOr<Tensor> ConvertTensor(std::unique_ptr<tf::Tensor> tensor) {
130*14675a02SAndroid Build Coastguard Worker FCP_ASSIGN_OR_RETURN(DataType dtype, ConvertDataType(tensor->dtype()));
131*14675a02SAndroid Build Coastguard Worker TensorShape shape = ConvertShape(tensor->shape());
132*14675a02SAndroid Build Coastguard Worker std::unique_ptr<TensorData> data;
133*14675a02SAndroid Build Coastguard Worker DTYPE_CASES(dtype, T, data = ConvertTensorData<T>(std::move(tensor)));
134*14675a02SAndroid Build Coastguard Worker return Tensor::Create(dtype, std::move(shape), std::move(data));
135*14675a02SAndroid Build Coastguard Worker }
136*14675a02SAndroid Build Coastguard Worker
137*14675a02SAndroid Build Coastguard Worker } // namespace fcp::aggregation::tensorflow
138