xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/core/agg_vector_aggregator_test.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/aggregation/core/agg_vector_aggregator.h"
18 
19 #include <cstdint>
20 #include <utility>
21 
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "fcp/aggregation/core/input_tensor_list.h"
25 #include "fcp/aggregation/core/tensor.h"
26 #include "fcp/aggregation/core/tensor_shape.h"
27 #include "fcp/aggregation/testing/test_data.h"
28 #include "fcp/aggregation/testing/testing.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/testing/testing.h"
31 
32 #ifndef FCP_NANOLIBC
33 #include "fcp/aggregation/core/tensor.pb.h"
34 #endif
35 
36 namespace fcp {
37 namespace aggregation {
38 namespace {
39 
40 using testing::Eq;
41 using testing::IsFalse;
42 using testing::IsTrue;
43 
44 // A simple Sum Aggregator
45 template <typename T>
46 class SumAggregator final : public AggVectorAggregator<T> {
47  public:
48   using AggVectorAggregator<T>::AggVectorAggregator;
49   using AggVectorAggregator<T>::data;
50 
51  private:
AggregateVector(const AggVector<T> & agg_vector)52   void AggregateVector(const AggVector<T>& agg_vector) override {
53     for (auto [i, v] : agg_vector) {
54       data()[i] += v;
55     }
56   }
57 };
58 
// Accumulates three scalar INT32 tensors (1, 2, 3) and expects the reported
// output to be a single scalar tensor holding their sum.
TEST(AggVectorAggregatorTest, ScalarAggregation_Succeeds) {
  SumAggregator<int32_t> aggregator(DT_INT32, {});
  Tensor t1 = Tensor::Create(DT_INT32, {}, CreateTestData({1})).value();
  Tensor t2 = Tensor::Create(DT_INT32, {}, CreateTestData({2})).value();
  Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
  EXPECT_THAT(aggregator.Accumulate(t1), IsOk());
  EXPECT_THAT(aggregator.Accumulate(t2), IsOk());
  EXPECT_THAT(aggregator.Accumulate(t3), IsOk());
  EXPECT_THAT(aggregator.CanReport(), IsTrue());

  auto result = std::move(aggregator).Report();
  // Fatal assertions: the statements below call result.value() and index its
  // first element, so the test must stop here if either precondition fails.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  // Verify the resulting tensor.
  EXPECT_THAT(result.value()[0], IsTensor({}, {6}));
}
75 
// Accumulates three INT32 tensors of shape {4} and expects the reported output
// to be a single dense tensor holding the element-wise sums.
TEST(AggVectorAggregatorTest, DenseAggregation_Succeeds) {
  const TensorShape shape = {4};
  SumAggregator<int32_t> aggregator(DT_INT32, shape);
  Tensor t1 =
      Tensor::Create(DT_INT32, shape, CreateTestData({1, 3, 15, 27})).value();
  Tensor t2 =
      Tensor::Create(DT_INT32, shape, CreateTestData({10, 5, 1, 2})).value();
  Tensor t3 =
      Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value();
  EXPECT_THAT(aggregator.Accumulate(t1), IsOk());
  EXPECT_THAT(aggregator.Accumulate(t2), IsOk());
  EXPECT_THAT(aggregator.Accumulate(t3), IsOk());
  EXPECT_THAT(aggregator.CanReport(), IsTrue());
  EXPECT_THAT(aggregator.GetNumInputs(), Eq(3));

  auto result = std::move(aggregator).Report();
  // Fatal assertions: the statements below call result.value() and index its
  // first element, so the test must stop here if either precondition fails.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  // Verify the resulting tensor.
  EXPECT_THAT(result.value()[0], IsTensor(shape, {14, 19, 23, 49}));
  // Also ensure that the resulting tensor is dense.
  EXPECT_TRUE(result.value()[0].is_dense());
}
99 
// Feeds one input into the first aggregator and two into the second, merges
// the second into the first, and expects the merged aggregator to report three
// inputs and the sum of all accumulated values.
TEST(AggVectorAggregationTest, Merge_Succeeds) {
  SumAggregator<int32_t> aggregator1(DT_INT32, {});
  SumAggregator<int32_t> aggregator2(DT_INT32, {});
  Tensor t1 = Tensor::Create(DT_INT32, {}, CreateTestData({1})).value();
  Tensor t2 = Tensor::Create(DT_INT32, {}, CreateTestData({2})).value();
  Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
  EXPECT_THAT(aggregator1.Accumulate(t1), IsOk());
  EXPECT_THAT(aggregator2.Accumulate(t2), IsOk());
  EXPECT_THAT(aggregator2.Accumulate(t3), IsOk());

  EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
  EXPECT_THAT(aggregator1.CanReport(), IsTrue());
  EXPECT_THAT(aggregator1.GetNumInputs(), Eq(3));

  auto result = std::move(aggregator1).Report();
  // Fatal assertions: the statement below calls result.value() and indexes its
  // first element, so the test must stop here if either precondition fails.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  EXPECT_THAT(result.value()[0], IsTensor({}, {6}));
}
119 
// Accumulating a DT_FLOAT tensor into an aggregator created for DT_INT32 is
// expected to be rejected with INVALID_ARGUMENT.
TEST(AggVectorAggregationTest, Aggregate_IncompatibleDataType) {
  SumAggregator<int32_t> int_aggregator(DT_INT32, {});
  Tensor float_tensor =
      Tensor::Create(DT_FLOAT, {}, CreateTestData<float>({0})).value();

  auto status = int_aggregator.Accumulate(float_tensor);
  EXPECT_THAT(status, IsCode(INVALID_ARGUMENT));
}
125 
// Accumulating a tensor of shape {2, 1} into an aggregator created with the
// scalar shape {} is expected to be rejected with INVALID_ARGUMENT.
TEST(AggVectorAggregationTest, Aggregate_IncompatibleShape) {
  SumAggregator<int32_t> scalar_aggregator(DT_INT32, {});
  Tensor mismatched_tensor =
      Tensor::Create(DT_INT32, {2, 1}, CreateTestData({0, 1})).value();

  auto status = scalar_aggregator.Accumulate(mismatched_tensor);
  EXPECT_THAT(status, IsCode(INVALID_ARGUMENT));
}
131 
// Merging a DT_FLOAT aggregator into a DT_INT32 aggregator is expected to be
// rejected with INVALID_ARGUMENT.
TEST(AggVectorAggregationTest, Merge_IncompatibleDataType) {
  SumAggregator<int32_t> int_aggregator(DT_INT32, {});
  SumAggregator<float> float_aggregator(DT_FLOAT, {});

  auto status = int_aggregator.MergeWith(std::move(float_aggregator));
  EXPECT_THAT(status, IsCode(INVALID_ARGUMENT));
}
138 
// Merging aggregators whose shapes differ ({3, 5} vs {5, 3}) is expected to be
// rejected with INVALID_ARGUMENT.
TEST(AggVectorAggregationTest, Merge_IncompatibleShape) {
  SumAggregator<int32_t> target(DT_INT32, {3, 5});
  SumAggregator<int32_t> source(DT_INT32, {5, 3});

  auto status = target.MergeWith(std::move(source));
  EXPECT_THAT(status, IsCode(INVALID_ARGUMENT));
}
145 
// Once Report() has consumed an aggregator, every further operation on that
// instance is expected to fail with FAILED_PRECONDITION (and CanReport() to
// return false). The instance is intentionally used after being moved from,
// hence the NOLINT markers.
TEST(AggVectorAggregationTest, FailsAfterBeingConsumed) {
  SumAggregator<int32_t> consumed(DT_INT32, {});
  EXPECT_THAT(std::move(consumed).Report(), IsOk());

  // Reporting is no longer possible.
  EXPECT_THAT(consumed.CanReport(), IsFalse());  // NOLINT
  auto second_report = std::move(consumed).Report();  // NOLINT
  EXPECT_THAT(second_report, IsCode(FAILED_PRECONDITION));

  // Accumulating another input fails.
  Tensor zero = Tensor::Create(DT_INT32, {}, CreateTestData({0})).value();
  EXPECT_THAT(consumed.Accumulate(zero),  // NOLINT
              IsCode(FAILED_PRECONDITION));

  // Merging a fresh aggregator into the consumed one fails.
  SumAggregator<int32_t> fresh_source(DT_INT32, {});
  EXPECT_THAT(consumed.MergeWith(std::move(fresh_source)),  // NOLINT
              IsCode(FAILED_PRECONDITION));

  // Passing the consumed aggregator as the argument of another MergeWith must
  // fail too.
  SumAggregator<int32_t> fresh_target(DT_INT32, {});
  EXPECT_THAT(fresh_target.MergeWith(std::move(consumed)),  // NOLINT
              IsCode(FAILED_PRECONDITION));
}
167 
// Constructing an aggregator whose element type (float) does not match the
// declared dtype (DT_INT32) is expected to terminate the process with a
// message matching "Incompatible dtype".
TEST(AggVectorAggregatorTest, TypeCheckFailure) {
  constexpr char kExpectedMessage[] = "Incompatible dtype";
  EXPECT_DEATH(new SumAggregator<float>(DT_INT32, {}), kExpectedMessage);
}
171 
172 }  // namespace
173 }  // namespace aggregation
174 }  // namespace fcp
175