xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/core/composite_key_combiner.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2023 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/aggregation/core/composite_key_combiner.h"
17 
18 #include <cstdint>
19 #include <cstring>
20 #include <memory>
21 #include <string>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26 
27 #include "fcp/aggregation/core/datatype.h"
28 #include "fcp/aggregation/core/input_tensor_list.h"
29 #include "fcp/aggregation/core/mutable_vector_data.h"
30 #include "fcp/aggregation/core/tensor.h"
31 #include "fcp/aggregation/core/tensor.pb.h"
32 #include "fcp/aggregation/core/tensor_shape.h"
33 #include "fcp/aggregation/core/vector_string_data.h"
34 #include "fcp/base/monitoring.h"
35 
36 namespace fcp {
37 namespace aggregation {
38 
39 namespace {
40 
// Returns true when values of type T fit in the 64-bit slot that a
// serialized composite key reserves for each component.
template <typename T>
bool CheckDataTypeSupported() {
  constexpr auto kSlotSize = sizeof(uint64_t);
  return sizeof(T) <= kSlotSize;
}
45 
46 template <>
CheckDataTypeSupported()47 bool CheckDataTypeSupported<string_view>() {
48   // We will store the representation of a pointer to the string as an integer,
49   // so ensure the size of a pointer is less than or equal to the size of a
50   // 64-bit integer.
51   return sizeof(intptr_t) == sizeof(uint64_t);
52 }
53 
// Copies the bytes pointed to by source_ptr to the destination pointed to by
// dest_ptr and advances source_ptr to the next T.
//
// The number of bytes copied will be the size of the type T.
//
// It is the responsibility of the caller to ensure that source_ptr is only used
// in subsequent code if it still points to a valid T after being incremented.
template <typename T>
void CopyToDest(const void*& source_ptr, uint64_t* dest_ptr,
                std::unordered_set<std::string>& intern_pool) {
  auto typed_source_ptr = static_cast<const T*>(source_ptr);
  // Copy exactly sizeof(T) bytes into the destination slot. std::memcpy is
  // used instead of a typed store through reinterpret_cast<T*>(dest_ptr)
  // because dest_ptr may point into a char buffer (the composite key string),
  // and writing a T through an aliased pointer there is undefined behavior
  // under the strict aliasing rules. Any trailing bytes of the 64-bit slot
  // keep their previous (zero-initialized) value.
  // It would be problematic if T were larger than uint64_t, but the Create
  // method validated that this was not the case.
  std::memcpy(dest_ptr, typed_source_ptr, sizeof(T));
  // Set source_ptr to point to the next T assuming that it points to
  // an array of T.
  source_ptr = static_cast<const void*>(typed_source_ptr + 1);
}
77 
78 // Specialization of CopyToDest for DT_STRING data type that interns the
79 // string_view pointed to by value_ptr. The address of the string in the
80 // intern pool is then converted to a 64 bit integer and copied to the
81 // destination pointed to by dest_ptr. Finally source_ptr is incremented to the
82 // next string_view.
83 //
84 // It is the responsibility of the caller to ensure that source_ptr is only used
85 // in subsequent code if it still points to a valid string_view after being
86 // incremented.
87 template <>
CopyToDest(const void * & source_ptr,uint64_t * dest_ptr,std::unordered_set<std::string> & intern_pool)88 void CopyToDest<string_view>(const void*& source_ptr, uint64_t* dest_ptr,
89                              std::unordered_set<std::string>& intern_pool) {
90   auto string_view_ptr = static_cast<const string_view*>(source_ptr);
91   // Insert the string into the intern pool if it does not already exist. This
92   // makes a copy of the string so that the intern pool owns the storage.
93   auto it = intern_pool.emplace(*string_view_ptr).first;
94   // The iterator of an unordered set may be invalidated by inserting more
95   // elements, but the pointer to the underlying element is guaranteed to be
96   // stable. https://en.cppreference.com/w/cpp/container/unordered_set
97   // Thus, get the address of the string after dereferencing the iterator.
98   const std::string* interned_string_ptr = &*it;
99   // The stable address of the string can be interpreted as a 64-bit integer.
100   intptr_t ptr_int = reinterpret_cast<intptr_t>(interned_string_ptr);
101   // Set the destination storage to the integer representation of the string
102   // address.
103   *dest_ptr = static_cast<uint64_t>(ptr_int);
104   // Set the source_ptr to point to the next string_view assuming that it points
105   // to an array of string_view.
106   source_ptr = static_cast<const void*>(++string_view_ptr);
107 }
108 
109 // Given a vector of uint64_t pointers, where the data pointed to can be safely
110 // interpreted as type T, returns a Tensor of underlying data type
111 // corresponding to T and the same length as the input vector. Each element of
112 // the tensor is created by interpreting the data pointed to by the uint64_t
113 // pointer at that index as type T.
114 template <typename T>
GetTensorForType(const std::vector<const uint64_t * > & key_iters)115 StatusOr<Tensor> GetTensorForType(
116     const std::vector<const uint64_t*>& key_iters) {
117   auto output_tensor_data = std::make_unique<MutableVectorData<T>>();
118   output_tensor_data->reserve(key_iters.size());
119   for (const uint64_t* key_it : key_iters) {
120     const T* ptr = reinterpret_cast<const T*>(key_it);
121     output_tensor_data->push_back(*ptr);
122   }
123   return Tensor::Create(internal::TypeTraits<T>::kDataType,
124                         TensorShape{key_iters.size()},
125                         std::move(output_tensor_data));
126 }
127 
128 // Specialization of GetTensorForType for DT_STRING data type.
129 // Given a vector of char pointers, where the data pointed to can be safely
130 // interpreted as a pointer to a string, returns a tensor of type DT_STRING
131 // and the same length as the input vector containing these strings.
132 // The returned tensor will own all strings it refers to and is thus safe to
133 // use after this class is destroyed.
134 template <>
GetTensorForType(const std::vector<const uint64_t * > & key_iters)135 StatusOr<Tensor> GetTensorForType<string_view>(
136     const std::vector<const uint64_t*>& key_iters) {
137   std::vector<std::string> strings_for_output;
138   for (auto key_it = key_iters.begin(); key_it != key_iters.end(); ++key_it) {
139     const intptr_t* ptr_to_string_address =
140         reinterpret_cast<const intptr_t*>(*key_it);
141     // The integer stored to represent a string is the address of the string
142     // stored in the intern_pool_. Thus this integer can be safely cast to a
143     // pointer and dereferenced to obtain the string.
144     const std::string* ptr =
145         reinterpret_cast<const std::string*>(*ptr_to_string_address);
146     strings_for_output.push_back(*ptr);
147   }
148   return Tensor::Create(
149       DT_STRING, TensorShape{key_iters.size()},
150       std::make_unique<VectorStringData>(std::move(strings_for_output)));
151 }
152 
153 }  // namespace
154 
CompositeKeyCombiner(std::vector<DataType> dtypes)155 CompositeKeyCombiner::CompositeKeyCombiner(std::vector<DataType> dtypes)
156     : dtypes_(dtypes) {
157   for (DataType dtype : dtypes) {
158     // Initialize to false to satisfy compiler that all cases in the DTYPE_CASES
159     // switch statement are covered, even though the cases that don't result in
160     // a value for data_type_supported will actually crash the program.
161     bool data_type_supported = false;
162     DTYPE_CASES(dtype, T, data_type_supported = CheckDataTypeSupported<T>());
163     FCP_CHECK(data_type_supported)
164         << "Unsupported data type for CompositeKeyCombiner: " << dtype;
165   }
166 }
167 
168 // Returns a single tensor containing the ordinals of the composite keys
169 // formed from the InputTensorList.
Accumulate(const InputTensorList & tensors)170 StatusOr<Tensor> CompositeKeyCombiner::Accumulate(
171     const InputTensorList& tensors) {
172   FCP_ASSIGN_OR_RETURN(TensorShape shape, CheckValidAndGetShape(tensors));
173 
174   // Determine the serialized size of the composite keys.
175   size_t composite_key_size = sizeof(uint64_t) * tensors.size();
176 
177   std::vector<const void*> iterators;
178   iterators.reserve(tensors.size());
179   for (const Tensor* t : tensors) {
180     iterators.push_back(t->data().data());
181   }
182 
183   // Iterate over all the TensorDataIterators at once to get the value for the
184   // composite key.
185   auto ordinals = std::make_unique<MutableVectorData<int64_t>>();
186   for (int i = 0; i < shape.NumElements(); ++i) {
187     // Create a string with the correct amount of memory to store an int64
188     // representation of the element in each input tensor at the current
189     // index.
190     std::string composite_key_data(composite_key_size, '\0');
191     uint64_t* key_ptr = reinterpret_cast<uint64_t*>(composite_key_data.data());
192 
193     for (int j = 0; j < tensors.size(); ++j) {
194       // Copy the 64-bit representation of the element into the position in the
195       // composite key data corresponding to this tensor.
196       DTYPE_CASES(dtypes_[j], T,
197                   CopyToDest<T>(iterators[j], key_ptr++, intern_pool_));
198     }
199     auto [it, inserted] = composite_keys_.insert(
200         {std::move(composite_key_data), composite_key_next_});
201     if (inserted) {
202       // This is the first time this CompositeKeyCombiner has encountered this
203       // particular composite key.
204       composite_key_next_++;
205       // Save the string representation of the key in order to recover the
206       // elements of the key when GetOutputKeys is called.
207       key_vec_.push_back(it->first);
208     }
209     // Insert the ordinal representing the composite key into the
210     // correct position in the output tensor.
211     ordinals->push_back(it->second);
212   }
213   return Tensor::Create(internal::TypeTraits<int64_t>::kDataType, shape,
214                         std::move(ordinals));
215 }
216 
GetOutputKeys() const217 StatusOr<std::vector<Tensor>> CompositeKeyCombiner::GetOutputKeys() const {
218   std::vector<Tensor> output_keys;
219   // Creating empty tensors is not allowed, so if there are no keys yet,
220   // which could happen if GetOutputKeys is called before Accumulate, return
221   // an empty vector.
222   if (key_vec_.empty()) return output_keys;
223   // Otherwise reserve space for a tensor for each data type.
224   output_keys.reserve(dtypes_.size());
225   std::vector<const uint64_t*> key_iters;
226   key_iters.reserve(key_vec_.size());
227   for (string_view s : key_vec_) {
228     key_iters.push_back(reinterpret_cast<const uint64_t*>(s.data()));
229   }
230   for (DataType dtype : dtypes_) {
231     StatusOr<Tensor> t;
232     DTYPE_CASES(dtype, T, t = GetTensorForType<T>(key_iters));
233     FCP_RETURN_IF_ERROR(t.status());
234     output_keys.push_back(std::move(t.value()));
235     for (auto key_it = key_iters.begin(); key_it != key_iters.end(); ++key_it) {
236       ++*key_it;
237     }
238   }
239   return output_keys;
240 }
241 
CheckValidAndGetShape(const InputTensorList & tensors)242 StatusOr<TensorShape> CompositeKeyCombiner::CheckValidAndGetShape(
243     const InputTensorList& tensors) {
244   if (tensors.size() == 0) {
245     return FCP_STATUS(INVALID_ARGUMENT)
246            << "InputTensorList must contain at least one tensor.";
247   } else if (tensors.size() != dtypes_.size()) {
248     return FCP_STATUS(INVALID_ARGUMENT)
249            << "InputTensorList size " << tensors.size()
250            << "is not the same as the length of expected dtypes "
251            << dtypes_.size();
252   }
253   // All the tensors in the input list should have the same shape and have
254   // a dense encoding.
255   const TensorShape* shape = nullptr;
256   for (int i = 0; i < tensors.size(); ++i) {
257     const Tensor* t = tensors[i];
258     if (shape == nullptr) {
259       shape = &t->shape();
260     } else {
261       if (*shape != t->shape()) {
262         return FCP_STATUS(INVALID_ARGUMENT)
263                << "All tensors in the InputTensorList must have the expected "
264                   "shape.";
265       }
266     }
267     if (!t->is_dense())
268       return FCP_STATUS(INVALID_ARGUMENT)
269              << "All tensors in the InputTensorList must be dense.";
270     // Ensure the data types of the input tensors match those provided to the
271     // constructor of this CompositeKeyCombiner.
272     DataType expected_dtype = dtypes_[i];
273     if (expected_dtype != t->dtype()) {
274       return FCP_STATUS(INVALID_ARGUMENT)
275              << "Tensor did not have expected dtype " << expected_dtype
276              << " and instead had dtype " << t->dtype();
277     }
278   }
279   return *shape;
280 }
281 
282 }  // namespace aggregation
283 }  // namespace fcp
284