xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/core/one_dim_grouping_aggregator_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2023 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/aggregation/core/one_dim_grouping_aggregator.h"
17 
18 #include <climits>
19 #include <cstdint>
20 #include <utility>
21 
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "fcp/aggregation/core/agg_vector.h"
25 #include "fcp/aggregation/core/tensor.h"
26 #include "fcp/aggregation/core/tensor_shape.h"
27 #include "fcp/aggregation/testing/test_data.h"
28 #include "fcp/aggregation/testing/testing.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/testing/testing.h"
31 
32 namespace fcp {
33 namespace aggregation {
34 namespace {
35 
36 using testing::Eq;
37 using testing::IsFalse;
38 using testing::IsTrue;
39 
40 // A simple Sum Aggregator
41 template <typename T>
42 class SumGroupingAggregator final : public OneDimGroupingAggregator<T> {
43  public:
44   using OneDimGroupingAggregator<T>::OneDimGroupingAggregator;
45   using OneDimGroupingAggregator<T>::data;
46 
47  private:
AggregateVectorByOrdinals(const AggVector<int64_t> & ordinals_vector,const AggVector<T> & value_vector)48   void AggregateVectorByOrdinals(const AggVector<int64_t>& ordinals_vector,
49                                  const AggVector<T>& value_vector) override {
50     auto value_it = value_vector.begin();
51     for (auto o : ordinals_vector) {
52       int64_t output_index = o.value;
53       // If this function returned a failed Status at this point, the
54       // data_vector_ may have already been partially modified, leaving the
55       // GroupingAggregator in a bad state. Thus, check that the indices of the
56       // ordinals tensor and the data tensor match with FCP_CHECK instead.
57       //
58       // TODO(team): Revisit the constraint that the indices of the
59       // values must match the indices of the ordinals when sparse tensors are
60       // implemented. It may be possible for the value to be omitted for a given
61       // ordinal in which case the default value should be used.
62       FCP_CHECK(value_it.index() == o.index)
63           << "Indices in AggVector of ordinals and AggVector of values "
64              "are mismatched.";
65       // Delegate the actual aggregation to the specific aggregation
66       // intrinsic implementation.
67       AggregateValue(output_index, value_it++.value());
68     }
69   }
70 
AggregateVector(const AggVector<T> & value_vector)71   void AggregateVector(const AggVector<T>& value_vector) override {
72     for (auto it : value_vector) {
73       AggregateValue(it.index, it.value);
74     }
75   }
76 
AggregateValue(int64_t i,T value)77   inline void AggregateValue(int64_t i, T value) { data()[i] += value; }
78 
GetDefaultValue()79   T GetDefaultValue() override { return static_cast<T>(0); }
80 };
81 
82 // A simple Min Aggregator that works for int32_t
83 class MinGroupingAggregator final : public OneDimGroupingAggregator<int32_t> {
84  public:
85   using OneDimGroupingAggregator<int32_t>::OneDimGroupingAggregator;
86   using OneDimGroupingAggregator<int32_t>::data;
87 
88  private:
AggregateVectorByOrdinals(const AggVector<int64_t> & ordinals_vector,const AggVector<int32_t> & value_vector)89   void AggregateVectorByOrdinals(
90       const AggVector<int64_t>& ordinals_vector,
91       const AggVector<int32_t>& value_vector) override {
92     auto value_it = value_vector.begin();
93     for (auto o : ordinals_vector) {
94       int64_t output_index = o.value;
95       // If this function returned a failed Status at this point, the
96       // data_vector_ may have already been partially modified, leaving the
97       // GroupingAggregator in a bad state. Thus, check that the indices of the
98       // ordinals tensor and the data tensor match with FCP_CHECK instead.
99       //
100       // TODO(team): Revisit the constraint that the indices of the
101       // values must match the indices of the ordinals when sparse tensors are
102       // implemented. It may be possible for the value to be omitted for a given
103       // ordinal in which case the default value should be used.
104       FCP_CHECK(value_it.index() == o.index)
105           << "Indices in AggVector of ordinals and AggVector of values "
106              "are mismatched.";
107       // Delegate the actual aggregation to the specific aggregation
108       // intrinsic implementation.
109       AggregateValue(output_index, value_it++.value());
110     }
111   }
112 
AggregateVector(const AggVector<int32_t> & value_vector)113   void AggregateVector(const AggVector<int32_t>& value_vector) override {
114     for (auto it : value_vector) {
115       AggregateValue(it.index, it.value);
116     }
117   }
118 
AggregateValue(int64_t i,int32_t value)119   inline void AggregateValue(int64_t i, int32_t value) {
120     if (value < data()[i]) {
121       data()[i] = value;
122     }
123   }
GetDefaultValue()124   int32_t GetDefaultValue() override { return INT_MAX; }
125 };
126 
// Reporting from an aggregator that never received input yields an OK result
// with no output tensors.
TEST(GroupingAggregatorTest, EmptyReport) {
  SumGroupingAggregator<int32_t> aggregator(DT_INT32);
  auto result = std::move(aggregator).Report();
  // Fatal check: the result is dereferenced on the next line.
  ASSERT_THAT(result, IsOk());
  EXPECT_THAT(result->size(), Eq(0));
}
133 
// Three scalar inputs that all map to ordinal 0 are summed into one element.
TEST(GroupingAggregatorTest, ScalarAggregation_Succeeds) {
  SumGroupingAggregator<int32_t> aggregator(DT_INT32);
  Tensor ordinal =
      Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value();
  Tensor t1 = Tensor::Create(DT_INT32, {}, CreateTestData({1})).value();
  Tensor t2 = Tensor::Create(DT_INT32, {}, CreateTestData({2})).value();
  Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
  EXPECT_THAT(aggregator.Accumulate({&ordinal, &t1}), IsOk());
  EXPECT_THAT(aggregator.Accumulate({&ordinal, &t2}), IsOk());
  EXPECT_THAT(aggregator.Accumulate({&ordinal, &t3}), IsOk());
  EXPECT_THAT(aggregator.CanReport(), IsTrue());

  auto result = std::move(aggregator).Report();
  // Fatal checks: the result is dereferenced and indexed below.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  // Verify the resulting tensor: 1 + 2 + 3.
  EXPECT_THAT(result.value()[0], IsTensor({1}, {6}));
}
152 
// Three inputs that share the identity ordinal mapping {0, 1, 2, 3} are summed
// element-wise.
TEST(GroupingAggregatorTest, DenseAggregation_Succeeds) {
  const TensorShape shape = {4};
  SumGroupingAggregator<int32_t> aggregator(DT_INT32);
  Tensor ordinals =
      Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({0, 1, 2, 3}))
          .value();
  Tensor t1 =
      Tensor::Create(DT_INT32, shape, CreateTestData({1, 3, 15, 27})).value();
  Tensor t2 =
      Tensor::Create(DT_INT32, shape, CreateTestData({10, 5, 1, 2})).value();
  Tensor t3 =
      Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value();
  EXPECT_THAT(aggregator.Accumulate({&ordinals, &t1}), IsOk());
  EXPECT_THAT(aggregator.Accumulate({&ordinals, &t2}), IsOk());
  EXPECT_THAT(aggregator.Accumulate({&ordinals, &t3}), IsOk());
  EXPECT_THAT(aggregator.CanReport(), IsTrue());
  EXPECT_THAT(aggregator.GetNumInputs(), Eq(3));

  auto result = std::move(aggregator).Report();
  // Fatal checks: the result is dereferenced and indexed below.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result->size(), Eq(1));
  // Verify the resulting tensor.
  EXPECT_THAT(result.value()[0], IsTensor(shape, {14, 19, 23, 49}));
  // Also ensure that the resulting tensor is dense.
  EXPECT_TRUE(result.value()[0].is_dense());
}
179 
// Each Accumulate call uses its own ordinal mapping; the output grows as new,
// larger ordinals appear.
TEST(GroupingAggregatorTest, DifferentOrdinalsPerAccumulate_Succeeds) {
  const TensorShape shape = {4};
  SumGroupingAggregator<int32_t> aggregator(DT_INT32);
  Tensor t1_ordinals =
      Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({3, 3, 2, 0}))
          .value();
  Tensor t1 =
      Tensor::Create(DT_INT32, shape, CreateTestData({1, 3, 15, 27})).value();
  EXPECT_THAT(aggregator.Accumulate({&t1_ordinals, &t1}), IsOk());
  // Totals: [27, 0, 15, 4]
  Tensor t2_ordinals =
      Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({1, 0, 1, 4}))
          .value();
  Tensor t2 =
      Tensor::Create(DT_INT32, shape, CreateTestData({10, 5, 1, 2})).value();
  EXPECT_THAT(aggregator.Accumulate({&t2_ordinals, &t2}), IsOk());
  // Totals: [32, 11, 15, 4, 2]
  Tensor t3_ordinals =
      Tensor::Create(DT_INT64, shape, CreateTestData<int64_t>({2, 2, 5, 1}))
          .value();
  Tensor t3 =
      Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value();
  EXPECT_THAT(aggregator.Accumulate({&t3_ordinals, &t3}), IsOk());
  // Totals: [32, 31, 29, 4, 2, 7]
  EXPECT_THAT(aggregator.CanReport(), IsTrue());
  EXPECT_THAT(aggregator.GetNumInputs(), Eq(3));

  auto result = std::move(aggregator).Report();
  // Fatal checks: the result is dereferenced and indexed below.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  // Verify the resulting tensor.
  EXPECT_THAT(result.value()[0], IsTensor({6}, {32, 31, 29, 4, 2, 7}));
  // Also ensure that the resulting tensor is dense.
  EXPECT_TRUE(result.value()[0].is_dense());
}
215 
// Each Accumulate call supplies tensors of a different length; the output
// covers the largest ordinal seen so far.
TEST(GroupingAggregatorTest, DifferentShapesPerAccumulate_Succeeds) {
  SumGroupingAggregator<int32_t> aggregator(DT_INT32);
  Tensor t1_ordinals =
      Tensor::Create(DT_INT64, {2}, CreateTestData<int64_t>({2, 0})).value();
  Tensor t1 = Tensor::Create(DT_INT32, {2}, CreateTestData({17, 3})).value();
  EXPECT_THAT(aggregator.Accumulate({&t1_ordinals, &t1}), IsOk());
  // Totals: [3, 0, 17]
  Tensor t2_ordinals =
      Tensor::Create(DT_INT64, {6}, CreateTestData<int64_t>({1, 0, 1, 4, 3, 0}))
          .value();
  Tensor t2 =
      Tensor::Create(DT_INT32, {6}, CreateTestData({10, 5, 13, 2, 4, 5}))
          .value();
  EXPECT_THAT(aggregator.Accumulate({&t2_ordinals, &t2}), IsOk());
  // Totals: [13, 23, 17, 4, 2]
  Tensor t3_ordinals =
      Tensor::Create(DT_INT64, {5}, CreateTestData<int64_t>({2, 2, 1, 0, 4}))
          .value();
  Tensor t3 =
      Tensor::Create(DT_INT32, {5}, CreateTestData({3, 11, 7, 6, 3})).value();
  EXPECT_THAT(aggregator.Accumulate({&t3_ordinals, &t3}), IsOk());
  // Totals: [19, 30, 31, 4, 5]
  EXPECT_THAT(aggregator.CanReport(), IsTrue());
  EXPECT_THAT(aggregator.GetNumInputs(), Eq(3));

  auto result = std::move(aggregator).Report();
  // Fatal checks: the result is dereferenced and indexed below.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  // Verify the resulting tensor.
  EXPECT_THAT(result.value()[0], IsTensor({5}, {19, 30, 31, 4, 5}));
  // Also ensure that the resulting tensor is dense.
  EXPECT_TRUE(result.value()[0].is_dense());
}
249 
TEST(GroupingAggregatorTest,
     DifferentShapesPerAccumulate_NonzeroDefaultValue_Succeeds) {
  // Use a MinGroupingAggregator which has a non-zero default value so we can
  // test that when the output grows, elements are set to the default value.
  MinGroupingAggregator aggregator(DT_INT32);
  Tensor t1_ordinals =
      Tensor::Create(DT_INT64, {2}, CreateTestData<int64_t>({2, 0})).value();
  Tensor t1 = Tensor::Create(DT_INT32, {2}, CreateTestData({17, 3})).value();
  EXPECT_THAT(aggregator.Accumulate({&t1_ordinals, &t1}), IsOk());
  // Running minima: [3, INT_MAX, 17]
  Tensor t2_ordinals =
      Tensor::Create(DT_INT64, {6}, CreateTestData<int64_t>({0, 0, 0, 4, 4, 0}))
          .value();
  Tensor t2 =
      Tensor::Create(DT_INT32, {6}, CreateTestData({10, 5, 13, 2, 4, -50}))
          .value();
  EXPECT_THAT(aggregator.Accumulate({&t2_ordinals, &t2}), IsOk());
  // Running minima: [-50, INT_MAX, 17, INT_MAX, 2]
  Tensor t3_ordinals =
      Tensor::Create(DT_INT64, {5}, CreateTestData<int64_t>({2, 2, 1, 0, 4}))
          .value();
  Tensor t3 =
      Tensor::Create(DT_INT32, {5}, CreateTestData({33, 11, 7, 6, 3})).value();
  EXPECT_THAT(aggregator.Accumulate({&t3_ordinals, &t3}), IsOk());
  // Running minima: [-50, 7, 11, INT_MAX, 2]
  EXPECT_THAT(aggregator.CanReport(), IsTrue());
  EXPECT_THAT(aggregator.GetNumInputs(), Eq(3));

  auto result = std::move(aggregator).Report();
  // Fatal checks: the result is dereferenced and indexed below.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  // Verify the resulting tensor; slot 3 never received a value and must still
  // hold the default.
  EXPECT_THAT(result.value()[0], IsTensor({5}, {-50, 7, 11, INT_MAX, 2}));
  // Also ensure that the resulting tensor is dense.
  EXPECT_TRUE(result.value()[0].is_dense());
}
286 
// Merging two scalar sum aggregators combines both their totals and their
// input counts.
TEST(GroupingAggregatorTest, Merge_Succeeds) {
  SumGroupingAggregator<int32_t> aggregator1(DT_INT32);
  SumGroupingAggregator<int32_t> aggregator2(DT_INT32);
  Tensor ordinal =
      Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value();
  Tensor t1 = Tensor::Create(DT_INT32, {}, CreateTestData({1})).value();
  Tensor t2 = Tensor::Create(DT_INT32, {}, CreateTestData({2})).value();
  Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
  EXPECT_THAT(aggregator1.Accumulate({&ordinal, &t1}), IsOk());
  EXPECT_THAT(aggregator2.Accumulate({&ordinal, &t2}), IsOk());
  EXPECT_THAT(aggregator2.Accumulate({&ordinal, &t3}), IsOk());

  EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
  EXPECT_THAT(aggregator1.CanReport(), IsTrue());
  EXPECT_THAT(aggregator1.GetNumInputs(), Eq(3));

  auto result = std::move(aggregator1).Report();
  // Fatal checks: the result is dereferenced and indexed below.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  EXPECT_THAT(result.value()[0], IsTensor({1}, {6}));
}
308 
// Merging two aggregators that never received input leaves an empty, still
// reportable aggregator.
TEST(GroupingAggregatorTest, Merge_BothEmpty_Succeeds) {
  SumGroupingAggregator<int32_t> aggregator1(DT_INT32);
  SumGroupingAggregator<int32_t> aggregator2(DT_INT32);

  // Merge the two empty aggregators together.
  EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
  EXPECT_THAT(aggregator1.CanReport(), IsTrue());
  EXPECT_THAT(aggregator1.GetNumInputs(), Eq(0));

  auto result = std::move(aggregator1).Report();
  // Fatal check: the result is dereferenced on the next line.
  ASSERT_THAT(result, IsOk());
  EXPECT_THAT(result->size(), Eq(0));
}
322 
// Merging a populated aggregator into one that has no output yet adopts the
// populated aggregator's totals and input count.
TEST(GroupingAggregatorTest, Merge_ThisOutputEmpty_Succeeds) {
  SumGroupingAggregator<int32_t> aggregator1(DT_INT32);
  SumGroupingAggregator<int32_t> aggregator2(DT_INT32);

  Tensor t1_ordinals =
      Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({3, 3, 2, 0}))
          .value();
  Tensor t1 =
      Tensor::Create(DT_INT32, {4}, CreateTestData({1, 3, 15, 27})).value();
  EXPECT_THAT(aggregator2.Accumulate({&t1_ordinals, &t1}), IsOk());
  // aggregator2 totals: [27, 0, 15, 4]
  Tensor t2_ordinals =
      Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({1, 0, 1, 4}))
          .value();
  Tensor t2 =
      Tensor::Create(DT_INT32, {4}, CreateTestData({10, 5, 1, 2})).value();
  EXPECT_THAT(aggregator2.Accumulate({&t2_ordinals, &t2}), IsOk());
  // aggregator2 totals: [32, 11, 15, 4, 2]

  // Merge aggregator2 into aggregator1 which has not received any inputs.
  EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
  EXPECT_THAT(aggregator1.CanReport(), IsTrue());
  EXPECT_THAT(aggregator1.GetNumInputs(), Eq(2));

  auto result = std::move(aggregator1).Report();
  // Fatal checks: the result is dereferenced and indexed below.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  // Verify the resulting tensor.
  EXPECT_THAT(result.value()[0], IsTensor({5}, {32, 11, 15, 4, 2}));
  // Also ensure that the resulting tensor is dense.
  EXPECT_TRUE(result.value()[0].is_dense());
}
355 
// Merging an aggregator that has no output into a populated one leaves the
// populated aggregator's totals and input count unchanged.
TEST(GroupingAggregatorTest, Merge_OtherOutputEmpty_Succeeds) {
  SumGroupingAggregator<int32_t> aggregator1(DT_INT32);
  SumGroupingAggregator<int32_t> aggregator2(DT_INT32);

  Tensor t1_ordinals =
      Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({3, 3, 2, 0}))
          .value();
  Tensor t1 =
      Tensor::Create(DT_INT32, {4}, CreateTestData({1, 3, 15, 27})).value();
  EXPECT_THAT(aggregator1.Accumulate({&t1_ordinals, &t1}), IsOk());
  // aggregator1 totals: [27, 0, 15, 4]
  Tensor t2_ordinals =
      Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({1, 0, 1, 4}))
          .value();
  Tensor t2 =
      Tensor::Create(DT_INT32, {4}, CreateTestData({10, 5, 1, 2})).value();
  EXPECT_THAT(aggregator1.Accumulate({&t2_ordinals, &t2}), IsOk());
  // aggregator1 totals: [32, 11, 15, 4, 2]

  // Merge with aggregator2 which has not received any inputs.
  EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
  EXPECT_THAT(aggregator1.CanReport(), IsTrue());
  EXPECT_THAT(aggregator1.GetNumInputs(), Eq(2));

  auto result = std::move(aggregator1).Report();
  // Fatal checks: the result is dereferenced and indexed below.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  // Verify the resulting tensor.
  EXPECT_THAT(result.value()[0], IsTensor({5}, {32, 11, 15, 4, 2}));
  // Also ensure that the resulting tensor is dense.
  EXPECT_TRUE(result.value()[0].is_dense());
}
388 
// Merging in an aggregator whose output is shorter only affects the slots the
// shorter output covers.
TEST(GroupingAggregatorTest, Merge_OtherOutputHasFewerElements_Succeeds) {
  SumGroupingAggregator<int32_t> aggregator1(DT_INT32);
  SumGroupingAggregator<int32_t> aggregator2(DT_INT32);

  Tensor t1_ordinals =
      Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({3, 3, 2, 0}))
          .value();
  Tensor t1 =
      Tensor::Create(DT_INT32, {4}, CreateTestData({1, 3, 15, 27})).value();
  EXPECT_THAT(aggregator1.Accumulate({&t1_ordinals, &t1}), IsOk());
  // aggregator1 totals: [27, 0, 15, 4]
  Tensor t2_ordinals =
      Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({1, 0, 1, 4}))
          .value();
  Tensor t2 =
      Tensor::Create(DT_INT32, {4}, CreateTestData({10, 5, 1, 2})).value();
  EXPECT_THAT(aggregator1.Accumulate({&t2_ordinals, &t2}), IsOk());
  // aggregator1 totals: [32, 11, 15, 4, 2]

  Tensor t3_ordinals =
      Tensor::Create(DT_INT64, {2}, CreateTestData<int64_t>({2, 2})).value();
  Tensor t3 = Tensor::Create(DT_INT32, {2}, CreateTestData({3, 11})).value();
  EXPECT_THAT(aggregator2.Accumulate({&t3_ordinals, &t3}), IsOk());
  // aggregator2 totals: [0, 0, 14]

  EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
  EXPECT_THAT(aggregator1.CanReport(), IsTrue());
  EXPECT_THAT(aggregator1.GetNumInputs(), Eq(3));

  auto result = std::move(aggregator1).Report();
  // Fatal checks: the result is dereferenced and indexed below.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  // Verify the resulting tensor.
  EXPECT_THAT(result.value()[0], IsTensor({5}, {32, 11, 29, 4, 2}));
  // Also ensure that the resulting tensor is dense.
  EXPECT_TRUE(result.value()[0].is_dense());
}
426 
// Merging in an aggregator whose output is longer grows this aggregator's
// output to the longer length.
TEST(GroupingAggregatorTest, Merge_OtherOutputHasMoreElements_Succeeds) {
  SumGroupingAggregator<int32_t> aggregator1(DT_INT32);
  SumGroupingAggregator<int32_t> aggregator2(DT_INT32);

  Tensor t1_ordinals =
      Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({3, 3, 2, 0}))
          .value();
  Tensor t1 =
      Tensor::Create(DT_INT32, {4}, CreateTestData({1, 3, 15, 27})).value();
  EXPECT_THAT(aggregator1.Accumulate({&t1_ordinals, &t1}), IsOk());
  // aggregator1 totals: [27, 0, 15, 4]
  Tensor t2_ordinals =
      Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({1, 0, 1, 4}))
          .value();
  Tensor t2 =
      Tensor::Create(DT_INT32, {4}, CreateTestData({10, 5, 1, 2})).value();
  EXPECT_THAT(aggregator1.Accumulate({&t2_ordinals, &t2}), IsOk());
  // aggregator1 totals: [32, 11, 15, 4, 2]

  Tensor t3_ordinals =
      Tensor::Create(DT_INT64, {4}, CreateTestData<int64_t>({2, 2, 5, 1}))
          .value();
  Tensor t3 =
      Tensor::Create(DT_INT32, {4}, CreateTestData({3, 11, 7, 20})).value();
  EXPECT_THAT(aggregator2.Accumulate({&t3_ordinals, &t3}), IsOk());
  // aggregator2 totals: [0, 20, 14, 0, 0, 7]

  EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
  EXPECT_THAT(aggregator1.CanReport(), IsTrue());
  EXPECT_THAT(aggregator1.GetNumInputs(), Eq(3));

  auto result = std::move(aggregator1).Report();
  // Fatal checks: the result is dereferenced and indexed below.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  // Verify the resulting tensor.
  EXPECT_THAT(result.value()[0], IsTensor({6}, {32, 31, 29, 4, 2, 7}));
  // Also ensure that the resulting tensor is dense.
  EXPECT_TRUE(result.value()[0].is_dense());
}
466 
TEST(GroupingAggregatorTest,
     Merge_OtherOutputHasMoreElements_NonzeroDefaultValue_Succeeds) {
  // Use a MinGroupingAggregator which has a non-zero default value so we can
  // test that when the output grows, elements are set to the default value.
  MinGroupingAggregator aggregator1(DT_INT32);
  MinGroupingAggregator aggregator2(DT_INT32);
  Tensor t1_ordinals =
      Tensor::Create(DT_INT64, {2}, CreateTestData<int64_t>({2, 0})).value();
  Tensor t1 = Tensor::Create(DT_INT32, {2}, CreateTestData({-17, 3})).value();
  EXPECT_THAT(aggregator1.Accumulate({&t1_ordinals, &t1}), IsOk());
  // aggregator1 minima: [3, INT_MAX, -17]

  Tensor t2_ordinals =
      Tensor::Create(DT_INT64, {6}, CreateTestData<int64_t>({0, 0, 0, 4, 4, 0}))
          .value();
  Tensor t2 =
      Tensor::Create(DT_INT32, {6}, CreateTestData({10, 5, 13, 2, 4, -50}))
          .value();
  EXPECT_THAT(aggregator2.Accumulate({&t2_ordinals, &t2}), IsOk());
  // aggregator2 minima: [-50, INT_MAX, INT_MAX, INT_MAX, 2]
  Tensor t3_ordinals =
      Tensor::Create(DT_INT64, {5}, CreateTestData<int64_t>({2, 2, 1, 0, 4}))
          .value();
  Tensor t3 =
      Tensor::Create(DT_INT32, {5}, CreateTestData({33, 11, 7, 6, 3})).value();
  EXPECT_THAT(aggregator2.Accumulate({&t3_ordinals, &t3}), IsOk());
  // aggregator2 minima: [-50, 7, 11, INT_MAX, 2]

  EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
  EXPECT_THAT(aggregator1.CanReport(), IsTrue());
  EXPECT_THAT(aggregator1.GetNumInputs(), Eq(3));

  auto result = std::move(aggregator1).Report();
  // Fatal checks: the result is dereferenced and indexed below.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  // Verify the resulting tensor; slot 3 received no value from either
  // aggregator and must still hold the default.
  EXPECT_THAT(result.value()[0], IsTensor({5}, {-50, 7, -17, INT_MAX, 2}));
  // Also ensure that the resulting tensor is dense.
  EXPECT_TRUE(result.value()[0].is_dense());
}
507 
// An ordinal tensor that is not DT_INT64 must be rejected. The value tensor
// deliberately matches the aggregator's dtype so that the ordinal dtype is the
// only invalid part of the input.
TEST(GroupingAggregatorTest, Aggregate_OrdinalTensorHasIncompatibleDataType) {
  SumGroupingAggregator<int32_t> aggregator(DT_INT32);
  Tensor ordinal =
      Tensor::Create(DT_INT32, {}, CreateTestData<int32_t>({0})).value();
  Tensor t = Tensor::Create(DT_INT32, {}, CreateTestData<int32_t>({0})).value();
  EXPECT_THAT(aggregator.Accumulate({&ordinal, &t}), IsCode(INVALID_ARGUMENT));
}
515 
// A DT_FLOAT value tensor passed to a DT_INT32 aggregator is rejected with
// INVALID_ARGUMENT.
TEST(GroupingAggregatorTest, Aggregate_IncompatibleDataType) {
  SumGroupingAggregator<int32_t> int_aggregator(DT_INT32);
  Tensor ordinals =
      Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value();
  Tensor float_values =
      Tensor::Create(DT_FLOAT, {}, CreateTestData<float>({0})).value();
  EXPECT_THAT(int_aggregator.Accumulate({&ordinals, &float_values}),
              IsCode(INVALID_ARGUMENT));
}
523 
// A scalar ordinal tensor paired with a two-element value tensor is rejected
// with INVALID_ARGUMENT.
TEST(GroupingAggregatorTest,
     Aggregate_OrdinalAndValueTensorsHaveIncompatibleShapes) {
  SumGroupingAggregator<int32_t> sum_aggregator(DT_INT32);
  Tensor scalar_ordinal =
      Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value();
  Tensor two_values =
      Tensor::Create(DT_INT32, {2}, CreateTestData({0, 1})).value();
  EXPECT_THAT(sum_aggregator.Accumulate({&scalar_ordinal, &two_values}),
              IsCode(INVALID_ARGUMENT));
}
532 
// Ordinal and value tensors of shape {2, 2} are rejected with
// INVALID_ARGUMENT.
TEST(GroupingAggregatorTest, Aggregate_MultidimensionalTensorsNotSupported) {
  SumGroupingAggregator<int32_t> sum_aggregator(DT_INT32);
  Tensor matrix_ordinals =
      Tensor::Create(DT_INT64, {2, 2}, CreateTestData<int64_t>({0, 0, 0, 0}))
          .value();
  Tensor matrix_values =
      Tensor::Create(DT_INT32, {2, 2}, CreateTestData({0, 1, 2, 3})).value();
  EXPECT_THAT(sum_aggregator.Accumulate({&matrix_ordinals, &matrix_values}),
              IsCode(INVALID_ARGUMENT));
}
542 
// Merging a DT_FLOAT aggregator into a DT_INT32 aggregator is rejected with
// INVALID_ARGUMENT.
TEST(GroupingAggregatorTest, Merge_IncompatibleDataType) {
  SumGroupingAggregator<int32_t> int_aggregator(DT_INT32);
  SumGroupingAggregator<float> float_aggregator(DT_FLOAT);
  EXPECT_THAT(int_aggregator.MergeWith(std::move(float_aggregator)),
              IsCode(INVALID_ARGUMENT));
}
549 
// After Report() has consumed the aggregator, every further operation on it is
// expected to fail with FAILED_PRECONDITION.
TEST(GroupingAggregatorTest, FailsAfterBeingConsumed) {
  Tensor ordinals =
      Tensor::Create(DT_INT64, {}, CreateTestData<int64_t>({0})).value();
  Tensor values = Tensor::Create(DT_INT32, {}, CreateTestData({0})).value();
  SumGroupingAggregator<int32_t> consumed(DT_INT32);
  EXPECT_THAT(consumed.Accumulate({&ordinals, &values}), IsOk());
  EXPECT_THAT(std::move(consumed).Report(), IsOk());

  // From here on `consumed` is deliberately used after being moved from; the
  // NOLINT markers silence the use-after-move lint.
  EXPECT_THAT(consumed.CanReport(), IsFalse());  // NOLINT
  EXPECT_THAT(std::move(consumed).Report(),
              IsCode(FAILED_PRECONDITION));             // NOLINT
  EXPECT_THAT(consumed.Accumulate({&ordinals, &values}),  // NOLINT
              IsCode(FAILED_PRECONDITION));
  EXPECT_THAT(
      consumed.MergeWith(SumGroupingAggregator<int32_t>(DT_INT32)),  // NOLINT
      IsCode(FAILED_PRECONDITION));

  // Using the consumed aggregator as the source of a merge must fail as well.
  SumGroupingAggregator<int32_t> fresh(DT_INT32);
  EXPECT_THAT(fresh.MergeWith(std::move(consumed)),  // NOLINT
              IsCode(FAILED_PRECONDITION));
}
574 
// Constructing an aggregator whose element type (float) disagrees with the
// requested dtype (DT_INT32) is expected to terminate the process.
TEST(GroupingAggregatorTest, TypeCheckFailure) {
  constexpr char kExpectedMessage[] = "Incompatible dtype";
  EXPECT_DEATH(new SumGroupingAggregator<float>(DT_INT32), kExpectedMessage);
}
579 
580 }  // namespace
581 }  // namespace aggregation
582 }  // namespace fcp
583