1 /* 2 * Copyright 2023 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef FCP_AGGREGATION_CORE_ONE_DIM_GROUPING_AGGREGATOR_H_ 17 #define FCP_AGGREGATION_CORE_ONE_DIM_GROUPING_AGGREGATOR_H_ 18 19 #include <cstddef> 20 #include <cstdint> 21 #include <memory> 22 #include <utility> 23 #include <vector> 24 25 #include "fcp/aggregation/core/agg_vector.h" 26 #include "fcp/aggregation/core/datatype.h" 27 #include "fcp/aggregation/core/input_tensor_list.h" 28 #include "fcp/aggregation/core/mutable_vector_data.h" 29 #include "fcp/aggregation/core/tensor.h" 30 #include "fcp/aggregation/core/tensor_aggregator.h" 31 #include "fcp/aggregation/core/tensor_data.h" 32 #include "fcp/aggregation/core/tensor_shape.h" 33 #include "fcp/base/monitoring.h" 34 35 namespace fcp { 36 namespace aggregation { 37 38 // GroupingAggregator class is a specialization of TensorAggregator which 39 // takes in a tensor containing ordinals and a tensor containing values, and 40 // accumulates the values into the output positions indicated by the 41 // corresponding ordinals. 42 // 43 // Currently only 1D input tensors are supported. 44 // 45 // The specific means of accumulating values and producing default values are 46 // left to the subclass. 47 // 48 // The implementation operates on AggVector<T> instances rather than tensors. 49 template <typename T> 50 class OneDimGroupingAggregator : public TensorAggregator { 51 public: 52 // TODO(team): Support accumulating tensors of multiple dimensions. In 53 // that case, the size of all dimensions but one (the dimension corresponding 54 // to the ordinal tensor) should be known in advance and thus this constructor 55 // should take in a shape with a single unknown dimension. OneDimGroupingAggregator(DataType dtype)56 explicit OneDimGroupingAggregator(DataType dtype) 57 : data_vector_(std::make_unique<MutableVectorData<T>>()), num_inputs_(0) { 58 FCP_CHECK(internal::TypeTraits<T>::kDataType == dtype) 59 << "Incompatible dtype"; 60 } 61 MergeWith(TensorAggregator && other)62 Status MergeWith(TensorAggregator&& other) override { 63 FCP_RETURN_IF_ERROR(CheckValid()); 64 OneDimGroupingAggregator<T>* other_ptr = 65 dynamic_cast<OneDimGroupingAggregator<T>*>(&other); 66 if (other_ptr == nullptr) { 67 return FCP_STATUS(INVALID_ARGUMENT) 68 << "GroupingAggregator::MergeOutputTensors: Can only merge with " 69 "another GroupingAggregator operating on the same dtype " 70 << internal::TypeTraits<T>::kDataType; 71 } 72 FCP_RETURN_IF_ERROR((*other_ptr).CheckValid()); 73 int other_num_inputs = other.GetNumInputs(); 74 OutputTensorList output_tensors = std::move(*other_ptr).TakeOutputs(); 75 76 if (output_tensors.size() == 1) { 77 AggVector<T> other_data_vector = output_tensors[0].AsAggVector<T>(); 78 if (other_data_vector.size() > data_vector_->size()) { 79 data_vector_->resize(other_data_vector.size(), GetDefaultValue()); 80 } 81 AggregateVector(other_data_vector); 82 } else { 83 // An empty output is valid and merging it into the current 84 // GroupingAggregator is a no-op. 85 FCP_CHECK(output_tensors.empty()) 86 << "GroupingAggregator::MergeOutputTensors: GroupingAggregator " 87 "should produce at most a single output tensor."; 88 } 89 90 num_inputs_ += other_num_inputs; 91 return FCP_STATUS(OK); 92 } 93 GetNumInputs()94 int GetNumInputs() const override { return num_inputs_; } 95 96 protected: 97 // Provides mutable access to the aggregator data as a vector<T> data()98 inline std::vector<T>& data() { return *data_vector_; } 99 100 // Implementation of the tensor aggregation. 101 // Expects 2 tensors as input: a tensor containing ordinals and a tensor 102 // containing values. 103 // 104 // Accumulates the values into the positions in the output tensor which are 105 // indicated by the corresponding ordinals. AggregateTensors(InputTensorList tensors)106 Status AggregateTensors(InputTensorList tensors) override { 107 FCP_CHECK(tensors.size() == 2) 108 << "GroupingAggregator should operate on 2 input tensors"; 109 110 const Tensor* ordinals = tensors[0]; 111 if (ordinals->dtype() != DT_INT64) { 112 return FCP_STATUS(INVALID_ARGUMENT) 113 << "GroupingAggregator::AggregateTensors: dtype mismatch for " 114 "tensor 0. Expected DT_INT64."; 115 } 116 const Tensor* tensor = tensors[1]; 117 if (tensor->dtype() != internal::TypeTraits<T>::kDataType) { 118 return FCP_STATUS(INVALID_ARGUMENT) 119 << "GroupingAggregator::AggregateTensors: dtype mismatch for " 120 "tensor 1"; 121 } 122 if (ordinals->shape() != tensor->shape()) { 123 return FCP_STATUS(INVALID_ARGUMENT) 124 << "GroupingAggregator::AggregateTensors: tensor shape mismatch. " 125 "Shape of both tensors must be the same."; 126 } 127 int num_dimensions = tensor->shape().dim_sizes().size(); 128 if (num_dimensions > 1) { 129 return FCP_STATUS(INVALID_ARGUMENT) 130 << "GroupingAggregator::AggregateTensors: Only 1 dimensional " 131 "tensors supported. Input tensor has " 132 << num_dimensions << " dimensions."; 133 } 134 if (!ordinals->is_dense() || !tensor->is_dense()) { 135 return FCP_STATUS(INVALID_ARGUMENT) 136 << "GroupingAggregator::AggregateTensors: Only dense tensors are " 137 "supported."; 138 } 139 num_inputs_++; 140 AggVector<T> value_vector = tensor->AsAggVector<T>(); 141 AggVector<int64_t> ordinals_vector = ordinals->AsAggVector<int64_t>(); 142 size_t final_size = data_vector_->size(); 143 for (auto o : ordinals_vector) { 144 if (o.value >= final_size) { 145 final_size = o.value + 1; 146 } 147 } 148 // Resize once outside the loop to avoid quadratic behavior. 149 data_vector_->resize(final_size, GetDefaultValue()); 150 AggregateVectorByOrdinals(ordinals_vector, value_vector); 151 return FCP_STATUS(OK); 152 } 153 CheckValid()154 Status CheckValid() const override { 155 if (data_vector_ == nullptr) { 156 return FCP_STATUS(FAILED_PRECONDITION) 157 << "GroupingAggregator::CheckValid: Output has already been " 158 "consumed."; 159 } 160 return FCP_STATUS(OK); 161 } 162 TakeOutputs()163 OutputTensorList TakeOutputs() && override { 164 OutputTensorList outputs = std::vector<Tensor>(); 165 if (!data_vector_->empty()) { 166 outputs.push_back(Tensor::Create(internal::TypeTraits<T>::kDataType, 167 TensorShape{data_vector_->size()}, 168 std::move(data_vector_)) 169 .value()); 170 } 171 data_vector_ = nullptr; 172 return outputs; 173 } 174 175 // Delegates AggVector aggregation by ordinal to a derived class. 176 // 177 // The size of the vector returned by data() must be greater than the largest 178 // ordinal in this vector. 179 // 180 // To avoid making a virtual function call per value in the tensor, the whole 181 // vector is passed to the subclass for aggregation, which provides better 182 // performance but comes at the cost of duplicated code between subclasses for 183 // iterating over the vectors. 184 virtual void AggregateVectorByOrdinals( 185 const AggVector<int64_t>& ordinals_vector, 186 const AggVector<T>& value_vector) = 0; 187 188 // Delegates AggVector aggregation to a derived class. 189 // 190 // This vector must be the same size as the vector returned by data(). 191 // 192 // To avoid making a virtual function call per value in the tensor, the whole 193 // vector is passed to the subclass for aggregation, which provides better 194 // performance but comes at the cost of duplicated code between subclasses for 195 // iterating over the vectors. 196 virtual void AggregateVector(const AggVector<T>& agg_vector) = 0; 197 198 // Delegates initialization of previously unseen ordinals to a derived class. 199 virtual T GetDefaultValue() = 0; 200 201 private: 202 std::unique_ptr<MutableVectorData<T>> data_vector_; 203 int num_inputs_; 204 }; 205 206 } // namespace aggregation 207 } // namespace fcp 208 209 #endif // FCP_AGGREGATION_CORE_ONE_DIM_GROUPING_AGGREGATOR_H_ 210