1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/aggregation/core/agg_vector_aggregator.h"
18
19 #include <cstdint>
20 #include <utility>
21
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 #include "fcp/aggregation/core/input_tensor_list.h"
25 #include "fcp/aggregation/core/tensor.h"
26 #include "fcp/aggregation/core/tensor_shape.h"
27 #include "fcp/aggregation/testing/test_data.h"
28 #include "fcp/aggregation/testing/testing.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/testing/testing.h"
31
32 #ifndef FCP_NANOLIBC
33 #include "fcp/aggregation/core/tensor.pb.h"
34 #endif
35
36 namespace fcp {
37 namespace aggregation {
38 namespace {
39
40 using testing::Eq;
41 using testing::IsFalse;
42 using testing::IsTrue;
43
44 // A simple Sum Aggregator
45 template <typename T>
46 class SumAggregator final : public AggVectorAggregator<T> {
47 public:
48 using AggVectorAggregator<T>::AggVectorAggregator;
49 using AggVectorAggregator<T>::data;
50
51 private:
AggregateVector(const AggVector<T> & agg_vector)52 void AggregateVector(const AggVector<T>& agg_vector) override {
53 for (auto [i, v] : agg_vector) {
54 data()[i] += v;
55 }
56 }
57 };
58
// Accumulates three scalar int32 tensors and checks that Report() produces a
// single scalar tensor holding their sum.
TEST(AggVectorAggregatorTest, ScalarAggregation_Succeeds) {
  SumAggregator<int32_t> aggregator(DT_INT32, {});
  Tensor t1 = Tensor::Create(DT_INT32, {}, CreateTestData({1})).value();
  Tensor t2 = Tensor::Create(DT_INT32, {}, CreateTestData({2})).value();
  Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
  EXPECT_THAT(aggregator.Accumulate(t1), IsOk());
  EXPECT_THAT(aggregator.Accumulate(t2), IsOk());
  EXPECT_THAT(aggregator.Accumulate(t3), IsOk());
  EXPECT_THAT(aggregator.CanReport(), IsTrue());
  EXPECT_THAT(aggregator.GetNumInputs(), Eq(3));

  auto result = std::move(aggregator).Report();
  // Fatal assertions: the checks below call result.value() and index [0],
  // which must not run if Report() failed or returned no tensors.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  // Verify the resulting tensor: 1 + 2 + 3 = 6.
  EXPECT_THAT(result.value()[0], IsTensor({}, {6}));
}
75
// Accumulates three dense int32 tensors of shape {4} and checks that Report()
// returns their element-wise sum as a single dense tensor.
TEST(AggVectorAggregatorTest, DenseAggregation_Succeeds) {
  const TensorShape shape = {4};
  SumAggregator<int32_t> aggregator(DT_INT32, shape);
  Tensor t1 =
      Tensor::Create(DT_INT32, shape, CreateTestData({1, 3, 15, 27})).value();
  Tensor t2 =
      Tensor::Create(DT_INT32, shape, CreateTestData({10, 5, 1, 2})).value();
  Tensor t3 =
      Tensor::Create(DT_INT32, shape, CreateTestData({3, 11, 7, 20})).value();
  EXPECT_THAT(aggregator.Accumulate(t1), IsOk());
  EXPECT_THAT(aggregator.Accumulate(t2), IsOk());
  EXPECT_THAT(aggregator.Accumulate(t3), IsOk());
  EXPECT_THAT(aggregator.CanReport(), IsTrue());
  EXPECT_THAT(aggregator.GetNumInputs(), Eq(3));

  auto result = std::move(aggregator).Report();
  // Fatal assertions: the checks below call result.value() and index [0],
  // which must not run if Report() failed or returned no tensors.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  // Verify the resulting tensor (element-wise sum of t1, t2 and t3).
  EXPECT_THAT(result.value()[0], IsTensor(shape, {14, 19, 23, 49}));
  // Also ensure that the resulting tensor is dense.
  EXPECT_TRUE(result.value()[0].is_dense());
}
99
// Merges one aggregator that saw a single input into another that saw two,
// and checks that the merged aggregator reports all three inputs and their sum.
TEST(AggVectorAggregationTest, Merge_Succeeds) {
  SumAggregator<int32_t> aggregator1(DT_INT32, {});
  SumAggregator<int32_t> aggregator2(DT_INT32, {});
  Tensor t1 = Tensor::Create(DT_INT32, {}, CreateTestData({1})).value();
  Tensor t2 = Tensor::Create(DT_INT32, {}, CreateTestData({2})).value();
  Tensor t3 = Tensor::Create(DT_INT32, {}, CreateTestData({3})).value();
  EXPECT_THAT(aggregator1.Accumulate(t1), IsOk());
  EXPECT_THAT(aggregator2.Accumulate(t2), IsOk());
  EXPECT_THAT(aggregator2.Accumulate(t3), IsOk());

  EXPECT_THAT(aggregator1.MergeWith(std::move(aggregator2)), IsOk());
  EXPECT_THAT(aggregator1.CanReport(), IsTrue());
  EXPECT_THAT(aggregator1.GetNumInputs(), Eq(3));

  auto result = std::move(aggregator1).Report();
  // Fatal assertions: the check below calls result.value() and indexes [0],
  // which must not run if Report() failed or returned no tensors.
  ASSERT_THAT(result, IsOk());
  ASSERT_THAT(result.value().size(), Eq(1));
  // 1 (from aggregator1) + 2 + 3 (merged from aggregator2) = 6.
  EXPECT_THAT(result.value()[0], IsTensor({}, {6}));
}
119
// Feeding a DT_FLOAT tensor into an aggregator declared for DT_INT32 must be
// rejected with INVALID_ARGUMENT.
TEST(AggVectorAggregationTest, Aggregate_IncompatibleDataType) {
  SumAggregator<int32_t> int_aggregator(DT_INT32, {});
  Tensor float_input =
      Tensor::Create(DT_FLOAT, {}, CreateTestData<float>({0})).value();
  EXPECT_THAT(int_aggregator.Accumulate(float_input),
              IsCode(INVALID_ARGUMENT));
}
125
// Feeding a {2, 1} tensor into an aggregator declared with a scalar shape must
// be rejected with INVALID_ARGUMENT.
TEST(AggVectorAggregationTest, Aggregate_IncompatibleShape) {
  SumAggregator<int32_t> scalar_aggregator(DT_INT32, {});
  Tensor matrix_input =
      Tensor::Create(DT_INT32, {2, 1}, CreateTestData({0, 1})).value();
  EXPECT_THAT(scalar_aggregator.Accumulate(matrix_input),
              IsCode(INVALID_ARGUMENT));
}
131
// Merging a DT_FLOAT aggregator into a DT_INT32 aggregator must be rejected
// with INVALID_ARGUMENT.
TEST(AggVectorAggregationTest, Merge_IncompatibleDataType) {
  SumAggregator<int32_t> int_aggregator(DT_INT32, {});
  SumAggregator<float> float_aggregator(DT_FLOAT, {});
  auto merge_status = int_aggregator.MergeWith(std::move(float_aggregator));
  EXPECT_THAT(merge_status, IsCode(INVALID_ARGUMENT));
}
138
// Merging aggregators whose shapes differ ({3, 5} vs {5, 3}) must be rejected
// with INVALID_ARGUMENT, even though both shapes hold the same number of
// elements.
TEST(AggVectorAggregationTest, Merge_IncompatibleShape) {
  SumAggregator<int32_t> target(DT_INT32, {3, 5});
  SumAggregator<int32_t> transposed(DT_INT32, {5, 3});
  auto merge_status = target.MergeWith(std::move(transposed));
  EXPECT_THAT(merge_status, IsCode(INVALID_ARGUMENT));
}
145
// After Report() has consumed an aggregator, every further operation on it
// must fail with FAILED_PRECONDITION, and so must passing it to another
// aggregator's MergeWith. CanReport() on it must return false.
TEST(AggVectorAggregationTest, FailsAfterBeingConsumed) {
  SumAggregator<int32_t> consumed(DT_INT32, {});
  EXPECT_THAT(std::move(consumed).Report(), IsOk());

  // `consumed` is intentionally used after being moved from below; the NOLINT
  // markers suppress lint warnings about those deliberate uses.
  EXPECT_THAT(consumed.CanReport(), IsFalse());  // NOLINT
  EXPECT_THAT(std::move(consumed).Report(),      // NOLINT
              IsCode(FAILED_PRECONDITION));
  EXPECT_THAT(consumed.Accumulate(  // NOLINT
                  Tensor::Create(DT_INT32, {}, CreateTestData({0})).value()),
              IsCode(FAILED_PRECONDITION));
  EXPECT_THAT(consumed.MergeWith(  // NOLINT
                  SumAggregator<int32_t>(DT_INT32, {})),
              IsCode(FAILED_PRECONDITION));

  // Using the consumed aggregator as the MergeWith argument must fail as well.
  SumAggregator<int32_t> receiver(DT_INT32, {});
  EXPECT_THAT(receiver.MergeWith(std::move(consumed)),  // NOLINT
              IsCode(FAILED_PRECONDITION));
}
167
// Constructing a SumAggregator<float> with the DT_INT32 dtype must terminate
// the process with a message matching "Incompatible dtype".
TEST(AggVectorAggregatorTest, TypeCheckFailure) {
  // Construct a scoped local instead of using a bare `new`, so nothing is
  // leaked if the constructor unexpectedly returns.
  EXPECT_DEATH({ SumAggregator<float> aggregator(DT_INT32, {}); },
               "Incompatible dtype");
}
171
172 } // namespace
173 } // namespace aggregation
174 } // namespace fcp
175