xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/core/composite_key_combiner.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2023 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_AGGREGATION_CORE_COMPOSITE_KEY_COMBINER_H_
18 #define FCP_AGGREGATION_CORE_COMPOSITE_KEY_COMBINER_H_
19 
20 #include <cstdint>
21 #include <string>
22 #include <unordered_map>
23 #include <unordered_set>
24 #include <vector>
25 
26 #include "fcp/aggregation/core/datatype.h"
27 #include "fcp/aggregation/core/input_tensor_list.h"
28 #include "fcp/aggregation/core/tensor.h"
29 #include "fcp/aggregation/core/tensor.pb.h"
30 #include "fcp/aggregation/core/tensor_shape.h"
31 #include "fcp/base/monitoring.h"
32 
33 namespace fcp {
34 namespace aggregation {
35 
36 // Class operating on sets of tensors of the same shape to combine indices for
37 // which the same combination of elements occurs, or in other words, indices
38 // containing the same composite key.
39 //
40 // This class contains two methods: Accumulate and GetOutputKeys, which can each
41 // be called multiple times.
42 //
43 // Accumulate takes in an InputTensorList of tensors of the same shape, and
44 // returns a Tensor of the same shape containing ordinals to represent the
45 // composite key that exists at each index. Composite keys are stored
46 // across calls to Accumulate, so if the same composite key is ever encountered
47 // in two different indices, whether in the same or a different call to
48 // Accumulate, the same ordinal will be returned in both these indices.
49 //
50 // GetOutputKeys returns the composite keys that have been seen in all previous
51 // calls to Accumulate, represented by a vector of Tensors. If the ordinal
52 // returned by Accumulate for that composite key was i, the composite key will
53 // be found at position i in the output vector.
54 //
55 // This class is not threadsafe.
56 class CompositeKeyCombiner {
57  public:
58   ~CompositeKeyCombiner() = default;
59 
60   // CompositeKeyCombiner is not copyable or moveable.
61   CompositeKeyCombiner(const CompositeKeyCombiner&) = delete;
62   CompositeKeyCombiner& operator=(const CompositeKeyCombiner&) = delete;
63   CompositeKeyCombiner(CompositeKeyCombiner&&) = delete;
64   CompositeKeyCombiner& operator=(CompositeKeyCombiner&&) = delete;
65 
66   // Creates a CompositeKeyCombiner if inputs are valid or crashes otherwise.
67   explicit CompositeKeyCombiner(std::vector<DataType> dtypes);
68 
69   // Returns a single tensor containing the ordinals of the composite keys
70   // formed from the tensors in the InputTensorList.
71   //
72   // All of the input tensors must have the same shape as each other, and the
73   // dtypes of the input tensors must match, in order, the dtypes provided to
74   // the constructor.
75   //
76   // For each index in the input tensors, the combination of elements from each
77   // tensor at that index forms a "composite key." Across calls to Accumulate,
78   // each unique composite key will be represented by a unique ordinal.
79   //
80   // The returned tensor is of data type DT_INT64 and has the same shape as the
81   // input tensors.
82   StatusOr<Tensor> Accumulate(const InputTensorList& tensors);
83 
84   // Obtains the vector of output keys ordered by their representative ordinal.
85   //
86   // The datatypes of the tensors in the output vector will match the data types
87   // provided to the constructor.
88   //
89   // For each unique combination of elements that was seen across all calls to
90   // Accumulate on this class so far, the vector of output tensors will include
91   // that combination of elements. The ordering of the elements within the
92   // output tensors will correspond to the ordinals returned by Accumulate. For
93   // example, if Accumulate returned the integer 5 in the output tensor at
94   // position 8 when it encountered this combination of elements in the input
95   // tensor list at position 8, then the elements in the composite key will
96   // appear at position 5 in the output tensors returned by this method.
97   StatusOr<std::vector<Tensor>> GetOutputKeys() const;
98 
99  private:
100   // Checks that the provided InputTensorList can be accumulated into this
101   // CompositeKeyCombiner. On success, yields what is presumably their shape.
102   StatusOr<TensorShape> CheckValidAndGetShape(const InputTensorList& tensors);
103 
104   // The data types of the tensors in valid inputs to Accumulate, in this exact
105   // order.
106   // TODO(team): Use inlined vector to store the DataTypes instead.
107   std::vector<DataType> dtypes_;
108   // String views of the composite keys in the order the keys will appear in the
109   // output tensors returned by GetOutputKeys.
110   std::vector<string_view> key_vec_;
111   // Set of unique strings encountered in tensors of type DT_STRING on calls to
112   // Accumulate.
113   // Used as an optimization to avoid storing the same string multiple
114   // times even if it appears in many composite keys.
115   // TODO(team): Intern directly into the output tensor instead to avoid
116   // copies when creating the output tensors.
117   std::unordered_set<std::string> intern_pool_;
118   // Mapping of string representations of the composite keys seen so far to
119   // their ordinal position in the output tensors returned by GetOutputKeys.
120   std::unordered_map<std::string, int64_t> composite_keys_;
121   // Number of unique composite keys encountered so far across all calls to
122   // Accumulate; presumably also the next ordinal to hand out (TODO confirm).
123   int64_t composite_key_next_ = 0;
124 };
125 
126 }  // namespace aggregation
127 }  // namespace fcp
128 
129 #endif  // FCP_AGGREGATION_CORE_COMPOSITE_KEY_COMBINER_H_
130