1 /* 2 * Copyright 2022 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_AGGREGATION_CORE_AGG_VECTOR_AGGREGATOR_H_ 18 #define FCP_AGGREGATION_CORE_AGG_VECTOR_AGGREGATOR_H_ 19 20 #include <cstdint> 21 #include <memory> 22 #include <utility> 23 #include <vector> 24 25 #include "fcp/aggregation/core/agg_vector.h" 26 #include "fcp/aggregation/core/datatype.h" 27 #include "fcp/aggregation/core/input_tensor_list.h" 28 #include "fcp/aggregation/core/mutable_vector_data.h" 29 #include "fcp/aggregation/core/tensor.h" 30 #include "fcp/aggregation/core/tensor_aggregator.h" 31 #include "fcp/aggregation/core/tensor_data.h" 32 #include "fcp/aggregation/core/tensor_shape.h" 33 #include "fcp/base/monitoring.h" 34 35 namespace fcp { 36 namespace aggregation { 37 38 // AggVectorAggregator class is a specialization of TensorAggregator which 39 // operates on AggVector<T> instances rather than tensors. 40 template <typename T> 41 class AggVectorAggregator : public TensorAggregator { 42 public: AggVectorAggregator(DataType dtype,TensorShape shape)43 AggVectorAggregator(DataType dtype, TensorShape shape) 44 : AggVectorAggregator(dtype, shape, 45 new MutableVectorData<T>(shape.NumElements())) {} 46 47 // Provides mutable access to the aggregator data as a vector<T> data()48 inline std::vector<T>& data() { return data_vector_; } 49 GetNumInputs()50 int GetNumInputs() const override { return num_inputs_; } 51 MergeWith(TensorAggregator && other)52 Status MergeWith(TensorAggregator&& other) override { 53 FCP_RETURN_IF_ERROR(CheckValid()); 54 FCP_ASSIGN_OR_RETURN(AggVectorAggregator<T> * other_ptr, CastOther(other)); 55 FCP_RETURN_IF_ERROR((*other_ptr).CheckValid()); 56 int64_t other_num_inputs = other.GetNumInputs(); 57 OutputTensorList output_tensors = std::move(*other_ptr).TakeOutputs(); 58 FCP_CHECK(output_tensors.size() == 1) 59 << "AggVectorAggregator::MergeOutputTensors: AggVectorAggregator " 60 "should produce a single output tensor"; 61 const Tensor& output = output_tensors[0]; 62 if (output.shape() != result_tensor_.shape()) { 63 return FCP_STATUS(INVALID_ARGUMENT) 64 << "AggVectorAggregator::MergeOutputTensors: tensor shape " 65 "mismatch"; 66 } 67 // Delegate the actual aggregation to the specific aggregation 68 // intrinsic implementation. 69 AggregateVector(output.AsAggVector<T>()); 70 num_inputs_ += other_num_inputs; 71 return FCP_STATUS(OK); 72 } 73 74 protected: 75 // Implementation of the tensor aggregation. AggregateTensors(InputTensorList tensors)76 Status AggregateTensors(InputTensorList tensors) override { 77 FCP_CHECK(tensors.size() == 1) 78 << "AggVectorAggregator should operate on a single input tensor"; 79 80 const Tensor* tensor = tensors[0]; 81 if (tensor->dtype() != internal::TypeTraits<T>::kDataType) { 82 return FCP_STATUS(INVALID_ARGUMENT) 83 << "AggVectorAggregator::AggregateTensors: dtype mismatch"; 84 } 85 if (tensor->shape() != result_tensor_.shape()) { 86 return FCP_STATUS(INVALID_ARGUMENT) 87 << "AggVectorAggregator::AggregateTensors: tensor shape mismatch"; 88 } 89 // Delegate the actual aggregation to the specific aggregation 90 // intrinsic implementation. 91 AggregateVector(tensor->AsAggVector<T>()); 92 num_inputs_++; 93 return FCP_STATUS(OK); 94 } 95 CheckValid()96 Status CheckValid() const override { return result_tensor_.CheckValid(); } 97 TakeOutputs()98 OutputTensorList TakeOutputs() && override { 99 OutputTensorList outputs = std::vector<Tensor>(); 100 outputs.push_back(std::move(result_tensor_)); 101 return outputs; 102 } 103 104 // Delegates AggVector aggregation to a derived class. 105 virtual void AggregateVector(const AggVector<T>& agg_vector) = 0; 106 107 private: AggVectorAggregator(DataType dtype,TensorShape shape,MutableVectorData<T> * data)108 AggVectorAggregator(DataType dtype, TensorShape shape, 109 MutableVectorData<T>* data) 110 : result_tensor_( 111 Tensor::Create(dtype, shape, std::unique_ptr<TensorData>(data)) 112 .value()), 113 data_vector_(*data), 114 num_inputs_(0) { 115 FCP_CHECK(internal::TypeTraits<T>::kDataType == dtype) 116 << "Incompatible dtype"; 117 } 118 CastOther(TensorAggregator & other)119 StatusOr<AggVectorAggregator<T>*> CastOther(TensorAggregator& other) { 120 #ifndef FCP_NANOLIBC 121 AggVectorAggregator<T>* other_ptr = 122 dynamic_cast<AggVectorAggregator<T>*>(&other); 123 if (other_ptr == nullptr) { 124 return FCP_STATUS(INVALID_ARGUMENT) 125 << "AggVectorAggregator::MergeOutputTensors: Can only merge with" 126 << "another AggVectorAggregator operating on the same dtype " 127 << internal::TypeTraits<T>::kDataType; 128 } 129 return other_ptr; 130 #else /* FCP_NANOLIBC */ 131 // When compiling in nanolibc we do not have access to runtime type 132 // information or std::type_traits. Thus we cannot use dynamic cast and use 133 // static_cast instead. 134 // This means we are relying on the caller to always call the MergeWith 135 // method on two TensorAggregators of the same underlying type, or the 136 // program will have undefined behavior due to a static_cast to the wrong 137 // type. 138 return static_cast<AggVectorAggregator<T>*>(&other); 139 #endif 140 } 141 142 Tensor result_tensor_; 143 std::vector<T>& data_vector_; 144 int num_inputs_; 145 }; 146 147 } // namespace aggregation 148 } // namespace fcp 149 150 #endif // FCP_AGGREGATION_CORE_AGG_VECTOR_AGGREGATOR_H_ 151