xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/tensorflow/converters_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/aggregation/tensorflow/converters.h"
18 
19 #include <initializer_list>
20 #include <memory>
21 #include <string>
22 
23 #include "gmock/gmock.h"
24 #include "gtest/gtest.h"
25 #include "fcp/aggregation/core/datatype.h"
26 #include "fcp/aggregation/core/tensor_shape.h"
27 #include "fcp/aggregation/core/tensor_spec.h"
28 #include "fcp/aggregation/testing/testing.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/testing/testing.h"
31 #include "tensorflow/core/framework/tensor.h"
32 #include "tensorflow/core/framework/tensor.pb.h"
33 #include "tensorflow/core/framework/tensor_shape.h"
34 #include "tensorflow/core/framework/tensor_shape.pb.h"
35 #include "tensorflow/core/framework/types.pb.h"
36 #include "tensorflow/core/protobuf/struct.pb.h"
37 
38 namespace fcp::aggregation::tensorflow {
39 namespace {
40 
41 namespace tf = ::tensorflow;
42 
CreateTfShape(std::initializer_list<int64_t> dim_sizes)43 tf::TensorShape CreateTfShape(std::initializer_list<int64_t> dim_sizes) {
44   tf::TensorShape shape;
45   EXPECT_TRUE(tf::TensorShape::BuildTensorShape(dim_sizes, &shape).ok());
46   return shape;
47 }
48 
CreateTfTensorSpec(const std::string & name,tf::DataType dtype,std::initializer_list<int64_t> dim_sizes)49 tf::TensorSpecProto CreateTfTensorSpec(
50     const std::string& name, tf::DataType dtype,
51     std::initializer_list<int64_t> dim_sizes) {
52   tf::TensorSpecProto spec;
53   spec.set_name(name);
54   spec.set_dtype(dtype);
55   for (auto dim_size : dim_sizes) {
56     spec.mutable_shape()->add_dim()->set_size(dim_size);
57   }
58   return spec;
59 }
60 
// Every supported TensorFlow dtype must convert to the aggregation-core dtype
// of the same name.
TEST(ConvertersTest, ConvertDataType_Success) {
  // Checks a single conversion. The status is verified before dereferencing:
  // dereferencing a non-OK StatusOr is undefined behavior and would crash the
  // test binary instead of reporting a readable failure.
  auto expect_converts = [](tf::DataType tf_dtype,
                            decltype(DT_FLOAT) expected) {
    auto converted = ConvertDataType(tf_dtype);
    ASSERT_THAT(converted, IsOk())
        << "converting " << tf::DataType_Name(tf_dtype);
    EXPECT_EQ(*converted, expected)
        << "converting " << tf::DataType_Name(tf_dtype);
  };

  expect_converts(tf::DT_FLOAT, DT_FLOAT);
  expect_converts(tf::DT_DOUBLE, DT_DOUBLE);
  expect_converts(tf::DT_INT32, DT_INT32);
  expect_converts(tf::DT_INT64, DT_INT64);
  expect_converts(tf::DT_STRING, DT_STRING);
}
68 
// DT_VARIANT is not among the dtypes ConvertDataType accepts, so the
// conversion must fail with INVALID_ARGUMENT.
TEST(ConvertersTest, ConvertDataType_Unsupported) {
  const auto converted = ConvertDataType(tf::DT_VARIANT);
  EXPECT_THAT(converted, IsCode(INVALID_ARGUMENT));
}
72 
// Rank-0, rank-1 and rank-2 TensorFlow shapes must convert to
// aggregation-core TensorShapes with identical dimension sizes.
TEST(ConvertersTest, ConvertShape_Success) {
  const auto scalar_shape = ConvertShape(CreateTfShape({}));
  EXPECT_EQ(scalar_shape, TensorShape({}));

  const auto vector_shape = ConvertShape(CreateTfShape({1}));
  EXPECT_EQ(vector_shape, TensorShape({1}));

  const auto matrix_shape = ConvertShape(CreateTfShape({2, 3}));
  EXPECT_EQ(matrix_shape, TensorShape({2, 3}));
}
78 
// A fully specified float spec must keep its name, dtype and shape through
// the conversion.
TEST(ConvertersTest, ConvertTensorSpec_Success) {
  const auto converted =
      ConvertTensorSpec(CreateTfTensorSpec("foo", tf::DT_FLOAT, {1, 2, 3}));
  ASSERT_THAT(converted, IsOk());

  const auto& spec = *converted;
  EXPECT_EQ(spec.name(), "foo");
  EXPECT_EQ(spec.dtype(), DT_FLOAT);
  EXPECT_EQ(spec.shape(), TensorShape({1, 2, 3}));
}
87 
// A spec whose dtype cannot be converted (DT_VARIANT) must be rejected with
// INVALID_ARGUMENT even though its shape is valid.
TEST(ConvertersTest, ConvertTensorSpec_UnsupportedDataType) {
  const auto converted =
      ConvertTensorSpec(CreateTfTensorSpec("foo", tf::DT_VARIANT, {1, 2, 3}));
  EXPECT_THAT(converted, IsCode(INVALID_ARGUMENT));
}
93 
// A spec containing a negative (-1) dimension size must be rejected with
// INVALID_ARGUMENT even though its dtype is supported.
TEST(ConvertersTest, ConvertTensorSpec_UnsupportedShape) {
  const auto converted =
      ConvertTensorSpec(CreateTfTensorSpec("foo", tf::DT_FLOAT, {1, -1}));
  EXPECT_THAT(converted, IsCode(INVALID_ARGUMENT));
}
99 
// A 2x3 DT_FLOAT tensor must convert to an aggregation-core float tensor with
// the same shape and the same values in the order listed in the proto.
TEST(ConvertersTest, ConvertTensor_Numeric) {
  tf::TensorProto tensor_proto = PARSE_TEXT_PROTO(R"pb(
    dtype: DT_FLOAT
    tensor_shape {
      dim { size: 2 }
      dim { size: 3 }
    }
    float_val: 1
    float_val: 2
    float_val: 3
    float_val: 4
    float_val: 5
    float_val: 6
  )pb");
  auto tensor = std::make_unique<tf::Tensor>();
  ASSERT_TRUE(tensor->FromProto(tensor_proto));
  auto converted = ConvertTensor(std::move(tensor));
  // Check the status before dereferencing: dereferencing a non-OK StatusOr is
  // undefined behavior and would crash the test instead of failing it.
  ASSERT_THAT(converted, IsOk());
  EXPECT_THAT(*converted, IsTensor<float>({2, 3}, {1, 2, 3, 4, 5, 6}));
}
119 
// A rank-1 DT_STRING tensor must convert to an aggregation-core string tensor
// holding the same strings in the same order.
TEST(ConvertersTest, ConvertTensor_String) {
  tf::TensorProto tensor_proto = PARSE_TEXT_PROTO(R"pb(
    dtype: DT_STRING
    tensor_shape { dim { size: 3 } }
    string_val: "abcd"
    string_val: "foobar"
    string_val: "zzzzzzzzzzzzzz"
  )pb");
  auto tensor = std::make_unique<tf::Tensor>();
  ASSERT_TRUE(tensor->FromProto(tensor_proto));
  auto converted = ConvertTensor(std::move(tensor));
  // Check the status before dereferencing: dereferencing a non-OK StatusOr is
  // undefined behavior and would crash the test instead of failing it.
  ASSERT_THAT(converted, IsOk());
  EXPECT_THAT(*converted,
              IsTensor<string_view>({3}, {"abcd", "foobar", "zzzzzzzzzzzzzz"}));
}
133 
// A scalar (rank-0) DT_STRING tensor must convert to a scalar aggregation-core
// string tensor holding the same single string.
TEST(ConvertersTest, ConvertTensor_ScalarString) {
  tf::TensorProto tensor_proto = PARSE_TEXT_PROTO(R"pb(
    dtype: DT_STRING
    tensor_shape {}
    string_val: "0123456789"
  )pb");
  auto tensor = std::make_unique<tf::Tensor>();
  ASSERT_TRUE(tensor->FromProto(tensor_proto));
  auto converted = ConvertTensor(std::move(tensor));
  // Check the status before dereferencing: dereferencing a non-OK StatusOr is
  // undefined behavior and would crash the test instead of failing it.
  ASSERT_THAT(converted, IsOk());
  EXPECT_THAT(*converted, IsTensor<string_view>({}, {"0123456789"}));
}
145 
// Converting a scalar DT_VARIANT tensor must fail with INVALID_ARGUMENT.
TEST(ConvertersTest, ConvertTensor_UnsupportedDataType) {
  auto variant_tensor =
      std::make_unique<tf::Tensor>(tf::DT_VARIANT, CreateTfShape({}));
  const auto converted = ConvertTensor(std::move(variant_tensor));
  EXPECT_THAT(converted, IsCode(INVALID_ARGUMENT));
}
150 
151 }  // namespace
152 }  // namespace fcp::aggregation::tensorflow
153