1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <memory>
18 #include <string>
19 #include <utility>
20
21 #include "fcp/aggregation/core/agg_vector_aggregator.h"
22 #include "fcp/aggregation/core/datatype.h"
23 #include "fcp/aggregation/core/tensor_aggregator_factory.h"
24 #include "fcp/aggregation/core/tensor_aggregator_registry.h"
25 #include "fcp/aggregation/core/tensor_shape.h"
26 #include "fcp/base/monitoring.h"
27
28 namespace fcp {
29 namespace aggregation {
30
31 // Implementation of a generic sum aggregator.
32 template <typename T>
33 class FederatedSum final : public AggVectorAggregator<T> {
34 public:
35 using AggVectorAggregator<T>::AggVectorAggregator;
36 using AggVectorAggregator<T>::data;
37
38 private:
AggregateVector(const AggVector<T> & agg_vector)39 void AggregateVector(const AggVector<T>& agg_vector) override {
40 for (auto v : agg_vector) {
41 data()[v.index] += v.value;
42 }
43 }
44 };
45
46 template <typename T>
CreateFederatedSum(DataType dtype,TensorShape shape)47 StatusOr<std::unique_ptr<TensorAggregator>> CreateFederatedSum(
48 DataType dtype, TensorShape shape) {
49 return std::unique_ptr<TensorAggregator>(new FederatedSum<T>(dtype, shape));
50 }
51
52 // Not supported for DT_STRING
53 template <>
CreateFederatedSum(DataType dtype,TensorShape shape)54 StatusOr<std::unique_ptr<TensorAggregator>> CreateFederatedSum<string_view>(
55 DataType dtype, TensorShape shape) {
56 return FCP_STATUS(INVALID_ARGUMENT)
57 << "FederatedSum isn't supported for DT_STRING datatype.";
58 }
59
60 // Factory class for the FederatedSum.
61 class FederatedSumFactory final : public TensorAggregatorFactory {
62 public:
63 FederatedSumFactory() = default;
64
65 // FederatedSumFactory isn't copyable or moveable.
66 FederatedSumFactory(const FederatedSumFactory&) = delete;
67 FederatedSumFactory& operator=(const FederatedSumFactory&) = delete;
68
Create(DataType dtype,TensorShape shape) const69 StatusOr<std::unique_ptr<TensorAggregator>> Create(
70 DataType dtype, TensorShape shape) const override {
71 StatusOr<std::unique_ptr<TensorAggregator>> aggregator;
72 DTYPE_CASES(dtype, T,
73 aggregator = CreateFederatedSum<T>(dtype, std::move(shape)));
74 return aggregator;
75 }
76 };
77
// Registers FederatedSumFactory with the aggregator registry under the name
// "federated_sum". The mechanism depends on the build configuration.
// TODO(team): Revise the registration mechanism below.
#ifdef FCP_BAREMETAL
// Bare-metal builds: registration happens only when this C-linkage function is
// called. It heap-allocates a factory and passes it to
// RegisterAggregatorFactory. NOTE(review): presumably the registry takes
// ownership and keeps the factory alive for the program's lifetime — confirm
// against tensor_aggregator_registry.h.
extern "C" void RegisterFederatedSum() {
  RegisterAggregatorFactory("federated_sum", new FederatedSumFactory());
}
#else
// All other builds: registration goes through the REGISTER_AGGREGATOR_FACTORY
// macro. NOTE(review): presumably this registers during static initialization
// — confirm against the macro's definition.
REGISTER_AGGREGATOR_FACTORY("federated_sum", FederatedSumFactory);
#endif
86
87 } // namespace aggregation
88 } // namespace fcp
89