xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/core/agg_vector_aggregator.h (revision 14675a029014e728ec732f129a32e299b2da0601)
/*
 * Copyright 2022 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef FCP_AGGREGATION_CORE_AGG_VECTOR_AGGREGATOR_H_
18 #define FCP_AGGREGATION_CORE_AGG_VECTOR_AGGREGATOR_H_
19 
20 #include <cstdint>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24 
25 #include "fcp/aggregation/core/agg_vector.h"
26 #include "fcp/aggregation/core/datatype.h"
27 #include "fcp/aggregation/core/input_tensor_list.h"
28 #include "fcp/aggregation/core/mutable_vector_data.h"
29 #include "fcp/aggregation/core/tensor.h"
30 #include "fcp/aggregation/core/tensor_aggregator.h"
31 #include "fcp/aggregation/core/tensor_data.h"
32 #include "fcp/aggregation/core/tensor_shape.h"
33 #include "fcp/base/monitoring.h"
34 
35 namespace fcp {
36 namespace aggregation {
37 
38 // AggVectorAggregator class is a specialization of TensorAggregator which
39 // operates on AggVector<T> instances rather than tensors.
40 template <typename T>
41 class AggVectorAggregator : public TensorAggregator {
42  public:
AggVectorAggregator(DataType dtype,TensorShape shape)43   AggVectorAggregator(DataType dtype, TensorShape shape)
44       : AggVectorAggregator(dtype, shape,
45                             new MutableVectorData<T>(shape.NumElements())) {}
46 
47   // Provides mutable access to the aggregator data as a vector<T>
data()48   inline std::vector<T>& data() { return data_vector_; }
49 
GetNumInputs()50   int GetNumInputs() const override { return num_inputs_; }
51 
MergeWith(TensorAggregator && other)52   Status MergeWith(TensorAggregator&& other) override {
53     FCP_RETURN_IF_ERROR(CheckValid());
54     FCP_ASSIGN_OR_RETURN(AggVectorAggregator<T> * other_ptr, CastOther(other));
55     FCP_RETURN_IF_ERROR((*other_ptr).CheckValid());
56     int64_t other_num_inputs = other.GetNumInputs();
57     OutputTensorList output_tensors = std::move(*other_ptr).TakeOutputs();
58     FCP_CHECK(output_tensors.size() == 1)
59         << "AggVectorAggregator::MergeOutputTensors: AggVectorAggregator "
60            "should produce a single output tensor";
61     const Tensor& output = output_tensors[0];
62     if (output.shape() != result_tensor_.shape()) {
63       return FCP_STATUS(INVALID_ARGUMENT)
64              << "AggVectorAggregator::MergeOutputTensors: tensor shape "
65                 "mismatch";
66     }
67     // Delegate the actual aggregation to the specific aggregation
68     // intrinsic implementation.
69     AggregateVector(output.AsAggVector<T>());
70     num_inputs_ += other_num_inputs;
71     return FCP_STATUS(OK);
72   }
73 
74  protected:
75   // Implementation of the tensor aggregation.
AggregateTensors(InputTensorList tensors)76   Status AggregateTensors(InputTensorList tensors) override {
77     FCP_CHECK(tensors.size() == 1)
78         << "AggVectorAggregator should operate on a single input tensor";
79 
80     const Tensor* tensor = tensors[0];
81     if (tensor->dtype() != internal::TypeTraits<T>::kDataType) {
82       return FCP_STATUS(INVALID_ARGUMENT)
83              << "AggVectorAggregator::AggregateTensors: dtype mismatch";
84     }
85     if (tensor->shape() != result_tensor_.shape()) {
86       return FCP_STATUS(INVALID_ARGUMENT)
87              << "AggVectorAggregator::AggregateTensors: tensor shape mismatch";
88     }
89     // Delegate the actual aggregation to the specific aggregation
90     // intrinsic implementation.
91     AggregateVector(tensor->AsAggVector<T>());
92     num_inputs_++;
93     return FCP_STATUS(OK);
94   }
95 
CheckValid()96   Status CheckValid() const override { return result_tensor_.CheckValid(); }
97 
TakeOutputs()98   OutputTensorList TakeOutputs() && override {
99     OutputTensorList outputs = std::vector<Tensor>();
100     outputs.push_back(std::move(result_tensor_));
101     return outputs;
102   }
103 
104   // Delegates AggVector aggregation to a derived class.
105   virtual void AggregateVector(const AggVector<T>& agg_vector) = 0;
106 
107  private:
AggVectorAggregator(DataType dtype,TensorShape shape,MutableVectorData<T> * data)108   AggVectorAggregator(DataType dtype, TensorShape shape,
109                       MutableVectorData<T>* data)
110       : result_tensor_(
111             Tensor::Create(dtype, shape, std::unique_ptr<TensorData>(data))
112                 .value()),
113         data_vector_(*data),
114         num_inputs_(0) {
115     FCP_CHECK(internal::TypeTraits<T>::kDataType == dtype)
116         << "Incompatible dtype";
117   }
118 
CastOther(TensorAggregator & other)119   StatusOr<AggVectorAggregator<T>*> CastOther(TensorAggregator& other) {
120 #ifndef FCP_NANOLIBC
121     AggVectorAggregator<T>* other_ptr =
122         dynamic_cast<AggVectorAggregator<T>*>(&other);
123     if (other_ptr == nullptr) {
124       return FCP_STATUS(INVALID_ARGUMENT)
125              << "AggVectorAggregator::MergeOutputTensors: Can only merge with"
126              << "another AggVectorAggregator operating on the same dtype "
127              << internal::TypeTraits<T>::kDataType;
128     }
129     return other_ptr;
130 #else /* FCP_NANOLIBC */
131     // When compiling in nanolibc we do not have access to runtime type
132     // information or std::type_traits. Thus we cannot use dynamic cast and use
133     // static_cast instead.
134     // This means we are relying on the caller to always call the MergeWith
135     // method on two TensorAggregators of the same underlying type, or the
136     // program will have undefined behavior due to a static_cast to the wrong
137     // type.
138     return static_cast<AggVectorAggregator<T>*>(&other);
139 #endif
140   }
141 
142   Tensor result_tensor_;
143   std::vector<T>& data_vector_;
144   int num_inputs_;
145 };
146 
147 }  // namespace aggregation
148 }  // namespace fcp
149 
150 #endif  // FCP_AGGREGATION_CORE_AGG_VECTOR_AGGREGATOR_H_
151