1 /*
2 * Copyright 2023 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "fcp/aggregation/core/composite_key_combiner.h"
17
18 #include <cstdint>
19 #include <cstring>
20 #include <memory>
21 #include <string>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26
27 #include "fcp/aggregation/core/datatype.h"
28 #include "fcp/aggregation/core/input_tensor_list.h"
29 #include "fcp/aggregation/core/mutable_vector_data.h"
30 #include "fcp/aggregation/core/tensor.h"
31 #include "fcp/aggregation/core/tensor.pb.h"
32 #include "fcp/aggregation/core/tensor_shape.h"
33 #include "fcp/aggregation/core/vector_string_data.h"
34 #include "fcp/base/monitoring.h"
35
36 namespace fcp {
37 namespace aggregation {
38
39 namespace {
40
// Reports whether a value of type T can be stored verbatim in the single
// 64-bit slot that each input tensor contributes to a composite key. Types
// wider than uint64_t cannot be represented and are reported as unsupported.
template <typename T>
bool CheckDataTypeSupported() {
  constexpr size_t kSlotBytes = sizeof(uint64_t);
  return !(sizeof(T) > kSlotBytes);
}
45
46 template <>
CheckDataTypeSupported()47 bool CheckDataTypeSupported<string_view>() {
48 // We will store the representation of a pointer to the string as an integer,
49 // so ensure the size of a pointer is less than or equal to the size of a
50 // 64-bit integer.
51 return sizeof(intptr_t) == sizeof(uint64_t);
52 }
53
// Copies the sizeof(T) bytes of the T pointed to by source_ptr into the start
// of the 64-bit slot pointed to by dest_ptr, then advances source_ptr to the
// next T (assuming it points into an array of T).
//
// Bytes of the slot beyond sizeof(T) are left untouched, so a caller that wants
// deterministic keys must zero the slot beforehand. T must not be larger than
// uint64_t; the CompositeKeyCombiner constructor validates this through
// CheckDataTypeSupported.
//
// The copy uses std::memcpy rather than assigning through a
// reinterpret_cast<T*>: the destination is typically std::string storage,
// which is only guaranteed to be char-aligned and does not hold an object of
// type T. A typed store there would be a misaligned access and a
// strict-aliasing violation (undefined behavior).
//
// intern_pool is unused by this generic version; it exists so the signature
// matches the string_view specialization.
//
// It is the responsibility of the caller to ensure that source_ptr is only used
// in subsequent code if it still points to a valid T after being incremented.
template <typename T>
void CopyToDest(const void*& source_ptr, uint64_t* dest_ptr,
                [[maybe_unused]] std::unordered_set<std::string>& intern_pool) {
  const T* typed_source_ptr = static_cast<const T*>(source_ptr);
  std::memcpy(dest_ptr, typed_source_ptr, sizeof(T));
  source_ptr = static_cast<const void*>(typed_source_ptr + 1);
}
77
78 // Specialization of CopyToDest for DT_STRING data type that interns the
79 // string_view pointed to by value_ptr. The address of the string in the
80 // intern pool is then converted to a 64 bit integer and copied to the
81 // destination pointed to by dest_ptr. Finally source_ptr is incremented to the
82 // next string_view.
83 //
84 // It is the responsibility of the caller to ensure that source_ptr is only used
85 // in subsequent code if it still points to a valid string_view after being
86 // incremented.
87 template <>
CopyToDest(const void * & source_ptr,uint64_t * dest_ptr,std::unordered_set<std::string> & intern_pool)88 void CopyToDest<string_view>(const void*& source_ptr, uint64_t* dest_ptr,
89 std::unordered_set<std::string>& intern_pool) {
90 auto string_view_ptr = static_cast<const string_view*>(source_ptr);
91 // Insert the string into the intern pool if it does not already exist. This
92 // makes a copy of the string so that the intern pool owns the storage.
93 auto it = intern_pool.emplace(*string_view_ptr).first;
94 // The iterator of an unordered set may be invalidated by inserting more
95 // elements, but the pointer to the underlying element is guaranteed to be
96 // stable. https://en.cppreference.com/w/cpp/container/unordered_set
97 // Thus, get the address of the string after dereferencing the iterator.
98 const std::string* interned_string_ptr = &*it;
99 // The stable address of the string can be interpreted as a 64-bit integer.
100 intptr_t ptr_int = reinterpret_cast<intptr_t>(interned_string_ptr);
101 // Set the destination storage to the integer representation of the string
102 // address.
103 *dest_ptr = static_cast<uint64_t>(ptr_int);
104 // Set the source_ptr to point to the next string_view assuming that it points
105 // to an array of string_view.
106 source_ptr = static_cast<const void*>(++string_view_ptr);
107 }
108
109 // Given a vector of uint64_t pointers, where the data pointed to can be safely
110 // interpreted as type T, returns a Tensor of underlying data type
111 // corresponding to T and the same length as the input vector. Each element of
112 // the tensor is created by interpreting the data pointed to by the uint64_t
113 // pointer at that index as type T.
114 template <typename T>
GetTensorForType(const std::vector<const uint64_t * > & key_iters)115 StatusOr<Tensor> GetTensorForType(
116 const std::vector<const uint64_t*>& key_iters) {
117 auto output_tensor_data = std::make_unique<MutableVectorData<T>>();
118 output_tensor_data->reserve(key_iters.size());
119 for (const uint64_t* key_it : key_iters) {
120 const T* ptr = reinterpret_cast<const T*>(key_it);
121 output_tensor_data->push_back(*ptr);
122 }
123 return Tensor::Create(internal::TypeTraits<T>::kDataType,
124 TensorShape{key_iters.size()},
125 std::move(output_tensor_data));
126 }
127
128 // Specialization of GetTensorForType for DT_STRING data type.
129 // Given a vector of char pointers, where the data pointed to can be safely
130 // interpreted as a pointer to a string, returns a tensor of type DT_STRING
131 // and the same length as the input vector containing these strings.
132 // The returned tensor will own all strings it refers to and is thus safe to
133 // use after this class is destroyed.
134 template <>
GetTensorForType(const std::vector<const uint64_t * > & key_iters)135 StatusOr<Tensor> GetTensorForType<string_view>(
136 const std::vector<const uint64_t*>& key_iters) {
137 std::vector<std::string> strings_for_output;
138 for (auto key_it = key_iters.begin(); key_it != key_iters.end(); ++key_it) {
139 const intptr_t* ptr_to_string_address =
140 reinterpret_cast<const intptr_t*>(*key_it);
141 // The integer stored to represent a string is the address of the string
142 // stored in the intern_pool_. Thus this integer can be safely cast to a
143 // pointer and dereferenced to obtain the string.
144 const std::string* ptr =
145 reinterpret_cast<const std::string*>(*ptr_to_string_address);
146 strings_for_output.push_back(*ptr);
147 }
148 return Tensor::Create(
149 DT_STRING, TensorShape{key_iters.size()},
150 std::make_unique<VectorStringData>(std::move(strings_for_output)));
151 }
152
153 } // namespace
154
CompositeKeyCombiner(std::vector<DataType> dtypes)155 CompositeKeyCombiner::CompositeKeyCombiner(std::vector<DataType> dtypes)
156 : dtypes_(dtypes) {
157 for (DataType dtype : dtypes) {
158 // Initialize to false to satisfy compiler that all cases in the DTYPE_CASES
159 // switch statement are covered, even though the cases that don't result in
160 // a value for data_type_supported will actually crash the program.
161 bool data_type_supported = false;
162 DTYPE_CASES(dtype, T, data_type_supported = CheckDataTypeSupported<T>());
163 FCP_CHECK(data_type_supported)
164 << "Unsupported data type for CompositeKeyCombiner: " << dtype;
165 }
166 }
167
168 // Returns a single tensor containing the ordinals of the composite keys
169 // formed from the InputTensorList.
Accumulate(const InputTensorList & tensors)170 StatusOr<Tensor> CompositeKeyCombiner::Accumulate(
171 const InputTensorList& tensors) {
172 FCP_ASSIGN_OR_RETURN(TensorShape shape, CheckValidAndGetShape(tensors));
173
174 // Determine the serialized size of the composite keys.
175 size_t composite_key_size = sizeof(uint64_t) * tensors.size();
176
177 std::vector<const void*> iterators;
178 iterators.reserve(tensors.size());
179 for (const Tensor* t : tensors) {
180 iterators.push_back(t->data().data());
181 }
182
183 // Iterate over all the TensorDataIterators at once to get the value for the
184 // composite key.
185 auto ordinals = std::make_unique<MutableVectorData<int64_t>>();
186 for (int i = 0; i < shape.NumElements(); ++i) {
187 // Create a string with the correct amount of memory to store an int64
188 // representation of the element in each input tensor at the current
189 // index.
190 std::string composite_key_data(composite_key_size, '\0');
191 uint64_t* key_ptr = reinterpret_cast<uint64_t*>(composite_key_data.data());
192
193 for (int j = 0; j < tensors.size(); ++j) {
194 // Copy the 64-bit representation of the element into the position in the
195 // composite key data corresponding to this tensor.
196 DTYPE_CASES(dtypes_[j], T,
197 CopyToDest<T>(iterators[j], key_ptr++, intern_pool_));
198 }
199 auto [it, inserted] = composite_keys_.insert(
200 {std::move(composite_key_data), composite_key_next_});
201 if (inserted) {
202 // This is the first time this CompositeKeyCombiner has encountered this
203 // particular composite key.
204 composite_key_next_++;
205 // Save the string representation of the key in order to recover the
206 // elements of the key when GetOutputKeys is called.
207 key_vec_.push_back(it->first);
208 }
209 // Insert the ordinal representing the composite key into the
210 // correct position in the output tensor.
211 ordinals->push_back(it->second);
212 }
213 return Tensor::Create(internal::TypeTraits<int64_t>::kDataType, shape,
214 std::move(ordinals));
215 }
216
GetOutputKeys() const217 StatusOr<std::vector<Tensor>> CompositeKeyCombiner::GetOutputKeys() const {
218 std::vector<Tensor> output_keys;
219 // Creating empty tensors is not allowed, so if there are no keys yet,
220 // which could happen if GetOutputKeys is called before Accumulate, return
221 // an empty vector.
222 if (key_vec_.empty()) return output_keys;
223 // Otherwise reserve space for a tensor for each data type.
224 output_keys.reserve(dtypes_.size());
225 std::vector<const uint64_t*> key_iters;
226 key_iters.reserve(key_vec_.size());
227 for (string_view s : key_vec_) {
228 key_iters.push_back(reinterpret_cast<const uint64_t*>(s.data()));
229 }
230 for (DataType dtype : dtypes_) {
231 StatusOr<Tensor> t;
232 DTYPE_CASES(dtype, T, t = GetTensorForType<T>(key_iters));
233 FCP_RETURN_IF_ERROR(t.status());
234 output_keys.push_back(std::move(t.value()));
235 for (auto key_it = key_iters.begin(); key_it != key_iters.end(); ++key_it) {
236 ++*key_it;
237 }
238 }
239 return output_keys;
240 }
241
CheckValidAndGetShape(const InputTensorList & tensors)242 StatusOr<TensorShape> CompositeKeyCombiner::CheckValidAndGetShape(
243 const InputTensorList& tensors) {
244 if (tensors.size() == 0) {
245 return FCP_STATUS(INVALID_ARGUMENT)
246 << "InputTensorList must contain at least one tensor.";
247 } else if (tensors.size() != dtypes_.size()) {
248 return FCP_STATUS(INVALID_ARGUMENT)
249 << "InputTensorList size " << tensors.size()
250 << "is not the same as the length of expected dtypes "
251 << dtypes_.size();
252 }
253 // All the tensors in the input list should have the same shape and have
254 // a dense encoding.
255 const TensorShape* shape = nullptr;
256 for (int i = 0; i < tensors.size(); ++i) {
257 const Tensor* t = tensors[i];
258 if (shape == nullptr) {
259 shape = &t->shape();
260 } else {
261 if (*shape != t->shape()) {
262 return FCP_STATUS(INVALID_ARGUMENT)
263 << "All tensors in the InputTensorList must have the expected "
264 "shape.";
265 }
266 }
267 if (!t->is_dense())
268 return FCP_STATUS(INVALID_ARGUMENT)
269 << "All tensors in the InputTensorList must be dense.";
270 // Ensure the data types of the input tensors match those provided to the
271 // constructor of this CompositeKeyCombiner.
272 DataType expected_dtype = dtypes_[i];
273 if (expected_dtype != t->dtype()) {
274 return FCP_STATUS(INVALID_ARGUMENT)
275 << "Tensor did not have expected dtype " << expected_dtype
276 << " and instead had dtype " << t->dtype();
277 }
278 }
279 return *shape;
280 }
281
282 } // namespace aggregation
283 } // namespace fcp
284