xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/core/composite_key_combiner_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2023 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/aggregation/core/composite_key_combiner.h"
17 
18 #include <cstdint>
19 #include <memory>
20 #include <vector>
21 
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "fcp/aggregation/core/agg_vector.h"
25 #include "fcp/aggregation/core/datatype.h"
26 #include "fcp/aggregation/core/input_tensor_list.h"
27 #include "fcp/aggregation/core/tensor.h"
28 #include "fcp/aggregation/core/tensor.pb.h"
29 #include "fcp/aggregation/core/tensor_shape.h"
30 #include "fcp/aggregation/testing/test_data.h"
31 #include "fcp/aggregation/testing/testing.h"
32 #include "fcp/base/monitoring.h"
33 #include "fcp/testing/testing.h"
34 
35 namespace fcp {
36 namespace aggregation {
37 namespace {
38 
39 using testing::Eq;
40 using testing::IsEmpty;
41 
// Accumulating an InputTensorList that holds no tensors at all is rejected
// with INVALID_ARGUMENT.
TEST(CompositeKeyCombinerTest, EmptyInput_Invalid) {
  CompositeKeyCombiner key_combiner(std::vector<DataType>{DT_FLOAT});
  StatusOr<Tensor> ordinals = key_combiner.Accumulate(InputTensorList({}));
  ASSERT_THAT(ordinals, IsCode(INVALID_ARGUMENT));
}
46 }
47 
// Key tensors with mismatched shapes ({3} for the float column, {4} for the
// int32 column) in a single Accumulate call are rejected with
// INVALID_ARGUMENT.
TEST(CompositeKeyCombinerTest, InputWithWrongShapeTensor_Invalid) {
  CompositeKeyCombiner key_combiner(std::vector<DataType>{DT_FLOAT, DT_INT32});
  // Four elements, one more than the float column below.
  Tensor int_keys =
      Tensor::Create(DT_INT32, {4}, CreateTestData<int32_t>({1, 2, 3, 4}))
          .value();
  Tensor float_keys =
      Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({1.1, 1.2, 1.3}))
          .value();
  StatusOr<Tensor> ordinals =
      key_combiner.Accumulate(InputTensorList({&float_keys, &int_keys}));
  ASSERT_THAT(ordinals, IsCode(INVALID_ARGUMENT));
}
58 }
59 
// The combiner is configured with two key types (DT_FLOAT, DT_INT32) but
// Accumulate receives only one tensor, so it fails with INVALID_ARGUMENT.
TEST(CompositeKeyCombinerTest, InputWithTooFewTensors_Invalid) {
  CompositeKeyCombiner combiner(std::vector<DataType>{DT_FLOAT, DT_INT32});
  Tensor t1 =
      Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({1.1, 1.2, 1.3}))
          .value();
  // Only the DT_FLOAT tensor is supplied; no tensor is provided for the
  // DT_INT32 key column.
  StatusOr<Tensor> result = combiner.Accumulate(InputTensorList({&t1}));
  ASSERT_THAT(result, IsCode(INVALID_ARGUMENT));
}
69 }
70 
// Supplying three key tensors to a combiner configured for two key types is
// rejected with INVALID_ARGUMENT.
TEST(CompositeKeyCombinerTest, InputWithTooManyTensors_Invalid) {
  CompositeKeyCombiner key_combiner(std::vector<DataType>{DT_FLOAT, DT_INT32});
  Tensor float_keys =
      Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({1.1, 1.2, 1.3}))
          .value();
  Tensor int_keys =
      Tensor::Create(DT_INT32, {3}, CreateTestData<int32_t>({1, 2, 3})).value();
  // Surplus tensor with no matching entry among the configured key types.
  Tensor extra_int_keys =
      Tensor::Create(DT_INT32, {3}, CreateTestData<int32_t>({4, 5, 6})).value();
  StatusOr<Tensor> ordinals = key_combiner.Accumulate(
      InputTensorList({&float_keys, &int_keys, &extra_int_keys}));
  ASSERT_THAT(ordinals, IsCode(INVALID_ARGUMENT));
}
83 }
84 
// A DT_INT32 tensor in the position where the combiner expects DT_STRING is
// rejected with INVALID_ARGUMENT.
TEST(CompositeKeyCombinerTest, InputWithWrongTypes_Invalid) {
  CompositeKeyCombiner key_combiner(
      std::vector<DataType>{DT_FLOAT, DT_STRING});
  Tensor float_keys =
      Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({1.1, 1.2, 1.3}))
          .value();
  // Deliberately DT_INT32 rather than the configured DT_STRING.
  Tensor mistyped_keys =
      Tensor::Create(DT_INT32, {3}, CreateTestData<int32_t>({1, 2, 3})).value();
  StatusOr<Tensor> ordinals =
      key_combiner.Accumulate(InputTensorList({&float_keys, &mistyped_keys}));
  ASSERT_THAT(ordinals, IsCode(INVALID_ARGUMENT));
}
94 }
95 
// Before any Accumulate call, GetOutputKeys succeeds and returns no tensors.
TEST(CompositeKeyCombinerTest, OutputBeforeAccumulate_Empty) {
  CompositeKeyCombiner key_combiner(std::vector<DataType>{DT_FLOAT});
  StatusOr<std::vector<Tensor>> output_keys = key_combiner.GetOutputKeys();
  ASSERT_OK(output_keys);
  EXPECT_THAT(output_keys.value(), IsEmpty());
}
101 }
102 
// A single one-element key tensor is assigned ordinal 0, and GetOutputKeys
// returns that key in a single one-element output tensor.
TEST(CompositeKeyCombinerTest, AccumulateAndOutput_SingleElement) {
  CompositeKeyCombiner combiner(std::vector<DataType>{DT_FLOAT});
  Tensor t1 =
      Tensor::Create(DT_FLOAT, {1}, CreateTestData<float>({1.3})).value();
  StatusOr<Tensor> result = combiner.Accumulate(InputTensorList({&t1}));
  ASSERT_OK(result);
  EXPECT_THAT(result.value(), IsTensor<int64_t>({1}, {0}));

  StatusOr<std::vector<Tensor>> output = combiner.GetOutputKeys();
  ASSERT_OK(output);
  // ASSERT (not EXPECT): on a size mismatch the indexing below would read
  // past the end of the vector.
  ASSERT_THAT(output.value().size(), Eq(1));
  EXPECT_THAT(output.value()[0], IsTensor<float>({1}, {1.3}));
}
115 }
116 
// Three key columns of different numeric types (float, int32, int64) whose
// three rows form three distinct composite keys yield ordinals 0, 1, 2.
// GetOutputKeys then returns one tensor per key column holding those rows in
// ordinal order.
TEST(CompositeKeyCombinerTest, AccumulateAndOutput_NumericTypes) {
  CompositeKeyCombiner combiner(
      std::vector<DataType>{DT_FLOAT, DT_INT32, DT_INT64});
  Tensor t1 =
      Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({1.1, 1.2, 1.3}))
          .value();
  Tensor t2 =
      Tensor::Create(DT_INT32, {3}, CreateTestData<int32_t>({1, 2, 3})).value();
  Tensor t3 =
      Tensor::Create(DT_INT64, {3}, CreateTestData<int64_t>({4, 5, 6})).value();
  StatusOr<Tensor> result =
      combiner.Accumulate(InputTensorList({&t1, &t2, &t3}));
  ASSERT_OK(result);
  EXPECT_THAT(result.value(), IsTensor<int64_t>({3}, {0, 1, 2}));

  StatusOr<std::vector<Tensor>> output = combiner.GetOutputKeys();
  ASSERT_OK(output);
  // ASSERT (not EXPECT) so a size mismatch stops the test before the
  // indexing below can read past the end of the vector.
  ASSERT_THAT(output.value().size(), Eq(3));
  EXPECT_THAT(output.value()[0], IsTensor<float>({3}, {1.1, 1.2, 1.3}));
  EXPECT_THAT(output.value()[1], IsTensor<int32_t>({3}, {1, 2, 3}));
  EXPECT_THAT(output.value()[2], IsTensor<int64_t>({3}, {4, 5, 6}));
}
138 }
139 
// Ordinals assigned to numeric composite keys persist across Accumulate
// calls. A key seen again, whether in the same or a later batch, gets its
// original ordinal, and a new key gets the next unused one.
TEST(CompositeKeyCombinerTest,
     NumericTypes_SameKeysResultInSameOrdinalsAcrossAccumulateCalls) {
  CompositeKeyCombiner combiner(
      std::vector<DataType>{DT_FLOAT, DT_INT32, DT_INT64});
  Tensor t1 =
      Tensor::Create(DT_FLOAT, {4}, CreateTestData<float>({1.1, 1.2, 1.1, 1.2}))
          .value();
  Tensor t2 =
      Tensor::Create(DT_INT32, {4}, CreateTestData<int32_t>({1, 2, 3, 2}))
          .value();
  Tensor t3 =
      Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({4, 5, 6, 5}))
          .value();
  StatusOr<Tensor> result1 =
      combiner.Accumulate(InputTensorList({&t1, &t2, &t3}));
  ASSERT_OK(result1);
  // Row 3 repeats row 1's composite key (1.2, 2, 5), so it reuses ordinal 1.
  EXPECT_THAT(result1.value(), IsTensor<int64_t>({4}, {0, 1, 2, 1}));

  // Across different calls to Accumulate, tensors can have different shape.
  Tensor t4 = Tensor::Create(DT_FLOAT, {5},
                             CreateTestData<float>({1.2, 1.1, 1.1, 1.1, 1.2}))
                  .value();
  Tensor t5 =
      Tensor::Create(DT_INT32, {5}, CreateTestData<int32_t>({2, 3, 2, 3, 2}))
          .value();
  Tensor t6 =
      Tensor::Create(DT_INT64, {5}, CreateTestData<int64_t>({5, 6, 5, 6, 5}))
          .value();
  StatusOr<Tensor> result2 =
      combiner.Accumulate(InputTensorList({&t4, &t5, &t6}));
  ASSERT_OK(result2);
  // (1.1, 2, 5) is the only composite key not seen in the first batch and
  // receives the next ordinal, 3. All other rows reuse earlier ordinals.
  EXPECT_THAT(result2.value(), IsTensor<int64_t>({5}, {1, 2, 3, 2, 1}));

  StatusOr<std::vector<Tensor>> output = combiner.GetOutputKeys();
  ASSERT_OK(output);
  // ASSERT (not EXPECT) so a size mismatch stops the test before the
  // indexing below can read past the end of the vector.
  ASSERT_THAT(output.value().size(), Eq(3));
  // One entry per distinct composite key, listed in ordinal order.
  EXPECT_THAT(output.value()[0], IsTensor<float>({4}, {1.1, 1.2, 1.1, 1.1}));
  EXPECT_THAT(output.value()[1], IsTensor<int32_t>({4}, {1, 2, 3, 2}));
  EXPECT_THAT(output.value()[2], IsTensor<int64_t>({4}, {4, 5, 6, 5}));
}
179 }
180 
// A float key column combined with two string key columns (including an
// empty string) forms three distinct composite keys with ordinals 0, 1, 2.
// GetOutputKeys returns the original values for every column.
TEST(CompositeKeyCombinerTest, AccumulateAndOutput_StringTypes) {
  CompositeKeyCombiner combiner(
      std::vector<DataType>{DT_FLOAT, DT_STRING, DT_STRING});
  Tensor t1 =
      Tensor::Create(DT_FLOAT, {3}, CreateTestData<float>({1.1, 1.2, 1.3}))
          .value();
  Tensor t2 = Tensor::Create(DT_STRING, {3},
                             CreateTestData<string_view>({"abc", "de", ""}))
                  .value();
  Tensor t3 =
      Tensor::Create(DT_STRING, {3},
                     CreateTestData<string_view>({"fghi", "jklmn", "o"}))
          .value();
  StatusOr<Tensor> result =
      combiner.Accumulate(InputTensorList({&t1, &t2, &t3}));
  ASSERT_OK(result);
  EXPECT_THAT(result.value(), IsTensor<int64_t>({3}, {0, 1, 2}));

  StatusOr<std::vector<Tensor>> output = combiner.GetOutputKeys();
  ASSERT_OK(output);
  // ASSERT (not EXPECT) so a size mismatch stops the test before the
  // indexing below can read past the end of the vector.
  ASSERT_THAT(output.value().size(), Eq(3));
  EXPECT_THAT(output.value()[0], IsTensor<float>({3}, {1.1, 1.2, 1.3}));
  EXPECT_THAT(output.value()[1], IsTensor<string_view>({3}, {"abc", "de", ""}));
  EXPECT_THAT(output.value()[2],
              IsTensor<string_view>({3}, {"fghi", "jklmn", "o"}));
}
206 }
207 
// Ordinals assigned to composite keys that contain string columns persist
// across Accumulate calls. A repeated composite key reuses its ordinal, and a
// new one gets the next unused ordinal.
TEST(CompositeKeyCombinerTest,
     StringTypes_SameCompositeKeysResultInSameOrdinalsAcrossAccumulateCalls) {
  CompositeKeyCombiner combiner(
      std::vector<DataType>{DT_FLOAT, DT_STRING, DT_STRING});
  Tensor t1 =
      Tensor::Create(DT_FLOAT, {4}, CreateTestData<float>({1.1, 1.2, 1.2, 1.3}))
          .value();
  Tensor t2 =
      Tensor::Create(DT_STRING, {4},
                     CreateTestData<string_view>({"abc", "de", "de", ""}))
          .value();
  Tensor t3 = Tensor::Create(
                  DT_STRING, {4},
                  CreateTestData<string_view>({"fghi", "jklmn", "jklmn", "o"}))
                  .value();
  StatusOr<Tensor> result1 =
      combiner.Accumulate(InputTensorList({&t1, &t2, &t3}));
  ASSERT_OK(result1);
  // Rows 1 and 2 share the composite key (1.2, "de", "jklmn").
  EXPECT_THAT(result1.value(), IsTensor<int64_t>({4}, {0, 1, 1, 2}));

  // Across different calls to Accumulate, tensors can have different shape.
  Tensor t4 = Tensor::Create(DT_FLOAT, {5},
                             CreateTestData<float>({1.3, 1.4, 1.1, 1.2, 1.1}))
                  .value();
  Tensor t5 = Tensor::Create(
                  DT_STRING, {5},
                  CreateTestData<string_view>({"", "abc", "abc", "de", "abc"}))
                  .value();
  Tensor t6 =
      Tensor::Create(
          DT_STRING, {5},
          CreateTestData<string_view>({"o", "pqrs", "fghi", "jklmn", "fghi"}))
          .value();
  StatusOr<Tensor> result2 =
      combiner.Accumulate(InputTensorList({&t4, &t5, &t6}));
  ASSERT_OK(result2);
  // (1.4, "abc", "pqrs") is the only composite key not seen in the first
  // batch and receives the next ordinal, 3.
  EXPECT_THAT(result2.value(), IsTensor<int64_t>({5}, {2, 3, 0, 1, 0}));

  StatusOr<std::vector<Tensor>> output = combiner.GetOutputKeys();
  // Check the status before calling value(), consistent with the other tests.
  ASSERT_OK(output);
  // ASSERT (not EXPECT) so a size mismatch stops the test before the
  // indexing below can read past the end of the vector.
  ASSERT_THAT(output.value().size(), Eq(3));
  // One entry per distinct composite key, listed in ordinal order.
  EXPECT_THAT(output.value()[0], IsTensor<float>({4}, {1.1, 1.2, 1.3, 1.4}));
  EXPECT_THAT(output.value()[1],
              IsTensor<string_view>({4}, {"abc", "de", "", "abc"}));
  EXPECT_THAT(output.value()[2],
              IsTensor<string_view>({4}, {"fghi", "jklmn", "o", "pqrs"}));
}
253 }
254 
255 }  // namespace
256 }  // namespace aggregation
257 }  // namespace fcp
258