xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/core/one_dim_grouping_aggregator.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2023 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef FCP_AGGREGATION_CORE_ONE_DIM_GROUPING_AGGREGATOR_H_
17 #define FCP_AGGREGATION_CORE_ONE_DIM_GROUPING_AGGREGATOR_H_
18 
19 #include <cstddef>
20 #include <cstdint>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24 
25 #include "fcp/aggregation/core/agg_vector.h"
26 #include "fcp/aggregation/core/datatype.h"
27 #include "fcp/aggregation/core/input_tensor_list.h"
28 #include "fcp/aggregation/core/mutable_vector_data.h"
29 #include "fcp/aggregation/core/tensor.h"
30 #include "fcp/aggregation/core/tensor_aggregator.h"
31 #include "fcp/aggregation/core/tensor_data.h"
32 #include "fcp/aggregation/core/tensor_shape.h"
33 #include "fcp/base/monitoring.h"
34 
35 namespace fcp {
36 namespace aggregation {
37 
38 // GroupingAggregator class is a specialization of TensorAggregator which
39 // takes in a tensor containing ordinals and a tensor containing values, and
40 // accumulates the values into the output positions indicated by the
41 // corresponding ordinals.
42 //
43 // Currently only 1D input tensors are supported.
44 //
45 // The specific means of accumulating values and producing default values are
46 // left to the subclass.
47 //
48 // The implementation operates on AggVector<T> instances rather than tensors.
49 template <typename T>
50 class OneDimGroupingAggregator : public TensorAggregator {
51  public:
52   // TODO(team): Support accumulating tensors of multiple dimensions. In
53   // that case, the size of all dimensions but one (the dimension corresponding
54   // to the ordinal tensor) should be known in advance and thus this constructor
55   // should take in a shape with a single unknown dimension.
OneDimGroupingAggregator(DataType dtype)56   explicit OneDimGroupingAggregator(DataType dtype)
57       : data_vector_(std::make_unique<MutableVectorData<T>>()), num_inputs_(0) {
58     FCP_CHECK(internal::TypeTraits<T>::kDataType == dtype)
59         << "Incompatible dtype";
60   }
61 
MergeWith(TensorAggregator && other)62   Status MergeWith(TensorAggregator&& other) override {
63     FCP_RETURN_IF_ERROR(CheckValid());
64     OneDimGroupingAggregator<T>* other_ptr =
65         dynamic_cast<OneDimGroupingAggregator<T>*>(&other);
66     if (other_ptr == nullptr) {
67       return FCP_STATUS(INVALID_ARGUMENT)
68              << "GroupingAggregator::MergeOutputTensors: Can only merge with "
69                 "another GroupingAggregator operating on the same dtype "
70              << internal::TypeTraits<T>::kDataType;
71     }
72     FCP_RETURN_IF_ERROR((*other_ptr).CheckValid());
73     int other_num_inputs = other.GetNumInputs();
74     OutputTensorList output_tensors = std::move(*other_ptr).TakeOutputs();
75 
76     if (output_tensors.size() == 1) {
77       AggVector<T> other_data_vector = output_tensors[0].AsAggVector<T>();
78       if (other_data_vector.size() > data_vector_->size()) {
79         data_vector_->resize(other_data_vector.size(), GetDefaultValue());
80       }
81       AggregateVector(other_data_vector);
82     } else {
83       // An empty output is valid and merging it into the current
84       // GroupingAggregator is a no-op.
85       FCP_CHECK(output_tensors.empty())
86           << "GroupingAggregator::MergeOutputTensors: GroupingAggregator "
87              "should produce at most a single output tensor.";
88     }
89 
90     num_inputs_ += other_num_inputs;
91     return FCP_STATUS(OK);
92   }
93 
GetNumInputs()94   int GetNumInputs() const override { return num_inputs_; }
95 
96  protected:
97   // Provides mutable access to the aggregator data as a vector<T>
data()98   inline std::vector<T>& data() { return *data_vector_; }
99 
100   // Implementation of the tensor aggregation.
101   // Expects 2 tensors as input: a tensor containing ordinals and a tensor
102   // containing values.
103   //
104   // Accumulates the values into the positions in the output tensor which are
105   // indicated by the corresponding ordinals.
AggregateTensors(InputTensorList tensors)106   Status AggregateTensors(InputTensorList tensors) override {
107     FCP_CHECK(tensors.size() == 2)
108         << "GroupingAggregator should operate on 2 input tensors";
109 
110     const Tensor* ordinals = tensors[0];
111     if (ordinals->dtype() != DT_INT64) {
112       return FCP_STATUS(INVALID_ARGUMENT)
113              << "GroupingAggregator::AggregateTensors: dtype mismatch for "
114                 "tensor 0. Expected DT_INT64.";
115     }
116     const Tensor* tensor = tensors[1];
117     if (tensor->dtype() != internal::TypeTraits<T>::kDataType) {
118       return FCP_STATUS(INVALID_ARGUMENT)
119              << "GroupingAggregator::AggregateTensors: dtype mismatch for "
120                 "tensor 1";
121     }
122     if (ordinals->shape() != tensor->shape()) {
123       return FCP_STATUS(INVALID_ARGUMENT)
124              << "GroupingAggregator::AggregateTensors: tensor shape mismatch. "
125                 "Shape of both tensors must be the same.";
126     }
127     int num_dimensions = tensor->shape().dim_sizes().size();
128     if (num_dimensions > 1) {
129       return FCP_STATUS(INVALID_ARGUMENT)
130              << "GroupingAggregator::AggregateTensors: Only 1 dimensional "
131                 "tensors supported. Input tensor has "
132              << num_dimensions << " dimensions.";
133     }
134     if (!ordinals->is_dense() || !tensor->is_dense()) {
135       return FCP_STATUS(INVALID_ARGUMENT)
136              << "GroupingAggregator::AggregateTensors: Only dense tensors are "
137                 "supported.";
138     }
139     num_inputs_++;
140     AggVector<T> value_vector = tensor->AsAggVector<T>();
141     AggVector<int64_t> ordinals_vector = ordinals->AsAggVector<int64_t>();
142     size_t final_size = data_vector_->size();
143     for (auto o : ordinals_vector) {
144       if (o.value >= final_size) {
145         final_size = o.value + 1;
146       }
147     }
148     // Resize once outside the loop to avoid quadratic behavior.
149     data_vector_->resize(final_size, GetDefaultValue());
150     AggregateVectorByOrdinals(ordinals_vector, value_vector);
151     return FCP_STATUS(OK);
152   }
153 
CheckValid()154   Status CheckValid() const override {
155     if (data_vector_ == nullptr) {
156       return FCP_STATUS(FAILED_PRECONDITION)
157              << "GroupingAggregator::CheckValid: Output has already been "
158                 "consumed.";
159     }
160     return FCP_STATUS(OK);
161   }
162 
TakeOutputs()163   OutputTensorList TakeOutputs() && override {
164     OutputTensorList outputs = std::vector<Tensor>();
165     if (!data_vector_->empty()) {
166       outputs.push_back(Tensor::Create(internal::TypeTraits<T>::kDataType,
167                                        TensorShape{data_vector_->size()},
168                                        std::move(data_vector_))
169                             .value());
170     }
171     data_vector_ = nullptr;
172     return outputs;
173   }
174 
175   // Delegates AggVector aggregation by ordinal to a derived class.
176   //
177   // The size of the vector returned by data() must be greater than the largest
178   // ordinal in this vector.
179   //
180   // To avoid making a virtual function call per value in the tensor, the whole
181   // vector is passed to the subclass for aggregation, which provides better
182   // performance but comes at the cost of duplicated code between subclasses for
183   // iterating over the vectors.
184   virtual void AggregateVectorByOrdinals(
185       const AggVector<int64_t>& ordinals_vector,
186       const AggVector<T>& value_vector) = 0;
187 
188   // Delegates AggVector aggregation to a derived class.
189   //
190   // This vector must be the same size as the vector returned by data().
191   //
192   // To avoid making a virtual function call per value in the tensor, the whole
193   // vector is passed to the subclass for aggregation, which provides better
194   // performance but comes at the cost of duplicated code between subclasses for
195   // iterating over the vectors.
196   virtual void AggregateVector(const AggVector<T>& agg_vector) = 0;
197 
198   // Delegates initialization of previously unseen ordinals to a derived class.
199   virtual T GetDefaultValue() = 0;
200 
201  private:
202   std::unique_ptr<MutableVectorData<T>> data_vector_;
203   int num_inputs_;
204 };
205 
206 }  // namespace aggregation
207 }  // namespace fcp
208 
209 #endif  // FCP_AGGREGATION_CORE_ONE_DIM_GROUPING_AGGREGATOR_H_
210