1 /* 2 * Copyright 2023 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_AGGREGATION_CORE_COMPOSITE_KEY_COMBINER_H_ 18 #define FCP_AGGREGATION_CORE_COMPOSITE_KEY_COMBINER_H_ 19 20 #include <cstdint> 21 #include <string> 22 #include <unordered_map> 23 #include <unordered_set> 24 #include <vector> 25 26 #include "fcp/aggregation/core/datatype.h" 27 #include "fcp/aggregation/core/input_tensor_list.h" 28 #include "fcp/aggregation/core/tensor.h" 29 #include "fcp/aggregation/core/tensor.pb.h" 30 #include "fcp/aggregation/core/tensor_shape.h" 31 #include "fcp/base/monitoring.h" 32 33 namespace fcp { 34 namespace aggregation { 35 36 // Class operating on sets of tensors of the same shape to combine indices for 37 // which the same combination of elements occurs, or in other words, indices 38 // containing the same composite key. 39 // 40 // This class contains two methods: Accumulate and GetOutputKeys, which can each 41 // be called multiple times. 42 // 43 // Accumulate takes in an InputTensorList of tensors of the same shape, and 44 // returns a Tensor of the same shape containing ordinals to represent the 45 // composite key that exists at each index. Composite keys are stored 46 // across calls to Accumulate, so if the same composite key is ever encountered 47 // in two different indices, whether in the same or a different call to 48 // Accumulate, the same ordinal will be returned in both these indices. 49 // 50 // GetOutputKeys returns the composite keys that have been seen in all previous 51 // calls to Accumulate, represented by a vector of Tensors. If the ordinal 52 // returned by Accumulate for that composite key was i, the composite key will 53 // be found at position i in the output vector. 54 // 55 // This class is not threadsafe. 56 class CompositeKeyCombiner { 57 public: 58 ~CompositeKeyCombiner() = default; 59 60 // CompositeKeyCombiner is not copyable or moveable. 61 CompositeKeyCombiner(const CompositeKeyCombiner&) = delete; 62 CompositeKeyCombiner& operator=(const CompositeKeyCombiner&) = delete; 63 CompositeKeyCombiner(CompositeKeyCombiner&&) = delete; 64 CompositeKeyCombiner& operator=(CompositeKeyCombiner&&) = delete; 65 66 // Creates a CompositeKeyCombiner if inputs are valid or crashes otherwise. 67 explicit CompositeKeyCombiner(std::vector<DataType> dtypes); 68 69 // Returns a single tensor containing the ordinals of the composite keys 70 // formed from the tensors in the InputTensorList. 71 // 72 // The shape of each of the input tensors must match the shape provided to the 73 // constructor, and the dtypes of the input tensors must match the dtypes 74 // provided to the constructor. 75 // 76 // For each index in the input tensors, the combination of elements from each 77 // tensor at that index forms a "composite key." Across calls to Accumulate, 78 // each unique composite key will be represented by a unique ordinal. 79 // 80 // The returned tensor is of data type DT_INT64 and the same shape that was 81 // provided to the constructor. 82 StatusOr<Tensor> Accumulate(const InputTensorList& tensors); 83 84 // Obtains the vector of output keys ordered by their representative ordinal. 85 // 86 // The datatypes of the tensors in the output vector will match the data types 87 // provided to the constructor. 88 // 89 // For each unique combination of elements that was seen across all calls to 90 // Accumulate on this class so far, the vector of output tensors will include 91 // that combination of elements. The ordering of the elements within the 92 // output tensors will correspond to the ordinals returned by Accumulate. For 93 // example, if Accumulate returned the integer 5 in the output tensor at 94 // position 8 when it encountered this combination of elements in the input 95 // tensor list at position 8, then the elements in the composite key will 96 // appear at position 5 in the output tensors returned by this method. 97 StatusOr<std::vector<Tensor>> GetOutputKeys() const; 98 99 private: 100 // Checks that the provided InputTensorList can be accumulated into this 101 // CompositeKeyCombiner. 102 StatusOr<TensorShape> CheckValidAndGetShape(const InputTensorList& tensors); 103 104 // The data types of the tensors in valid inputs to Accumulate, in this exact 105 // order. 106 // TODO(team): Use inlined vector to store the DataTypes instead. 107 std::vector<DataType> dtypes_; 108 // String views of the composite keys in the order the keys will appear in the 109 // output tensors returned by GetOutputKeys. 110 std::vector<string_view> key_vec_; 111 // Set of unique strings encountered in tensors of type DT_STRING on calls to 112 // Accumulate. 113 // Used as an optimization to avoid storing the same string multiple 114 // times even if it appears in many composite keys. 115 // TODO(team): Intern directly into the output tensor instead to avoid 116 // copies when creating the output tensors. 117 std::unordered_set<std::string> intern_pool_; 118 // Mapping of string representations of the composite keys seen so far to 119 // their ordinal position in the output tensors returned by GetOutputKeys. 120 std::unordered_map<std::string, int64_t> composite_keys_; 121 // Number of unique composite keys encountered so far across all calls to 122 // Accumulate. 123 int64_t composite_key_next_ = 0; 124 }; 125 126 } // namespace aggregation 127 } // namespace fcp 128 129 #endif // FCP_AGGREGATION_CORE_COMPOSITE_KEY_COMBINER_H_ 130