xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/mapped_ptr_container_sorter.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Below is an example usage, in which clone is sorted according to the order
// of original, using map_fn to map from pointers in original to pointers in
// clone.
//
//   std::vector<std::unique_ptr<HloInstruction>> original = ...;
//   std::vector<std::unique_ptr<HloInstruction>> clone = ...;
//   HloCloneContext* ctx = ...;
//   using Sorter = MappedPtrContainerSorter<HloInstruction>;
//   Sorter::MapPtrFn map_fn = [ctx](const HloInstruction* i) {
//       return ctx->FindInstruction(i);
//     };
//
//   auto status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(),
//                              original, clone);
29 
30 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAPPED_PTR_CONTAINER_SORTER_H_
31 #define TENSORFLOW_COMPILER_XLA_SERVICE_MAPPED_PTR_CONTAINER_SORTER_H_
32 
33 #include <array>
34 #include <cstddef>
35 #include <functional>
36 #include <limits>
37 #include <list>
38 #include <memory>
39 #include <string>
40 #include <utility>
41 #include <vector>
42 
43 #include "absl/container/flat_hash_map.h"
44 #include "absl/container/flat_hash_set.h"
45 #include "absl/strings/str_cat.h"
46 #include "absl/strings/str_join.h"
47 #include "tensorflow/compiler/xla/status.h"
48 #include "tensorflow/compiler/xla/statusor.h"
49 #include "tensorflow/compiler/xla/util.h"
50 #include "tensorflow/core/platform/errors.h"
51 #include "tensorflow/core/platform/logging.h"
52 #include "tensorflow/core/platform/statusor.h"
53 
54 namespace xla {
55 
56 // A class for sorting an unordered container of pointers according to the sort
57 // order of an ordered container of pointers. Sorting is stable.
58 //
59 // Terminology:
60 // - unmapped element: An element from the unordered container that does not
61 //   have a corresponding element in the ordered container.
62 template <typename PointedToTy>
63 class MappedPtrContainerSorter {
64  public:
65   // A function to map elements from an ordered container to elements in an
66   // unordered container. Not every element in ordered_container need map to an
67   // element in unordered_container and vice versa.
68   using MapPtrFn = std::function<const PointedToTy*(const PointedToTy*)>;
69 
70   // A function that maps unmapped elements (from an unordered container) to an
71   // index in the final sorted result. The returned index indicates that the
72   // unmapped element should be placed just after the mapped element at that
73   // index, in the result without unmapped elements. See
74   // IndexBeforeMappedElementsFn() and IndexAfterMappedElementsFn() for how to
75   // indicate that an unmapped element should be placed before or after all
76   // mapped elements, respectively. Unmapped elements destined for the same
77   // index will retain their order from the unordered container.
78   using UnmappedPtrIndexFn = std::function<size_t(const PointedToTy*)>;
79 
80   // Functions that return an UnmappedElementIndexFn that indicates that
81   // ummapped elements (from an unordered container) should be placed before or
82   // after all mapped elements, respectively.
83   static const UnmappedPtrIndexFn& IndexBeforeMappedElementsFn();
84   static const UnmappedPtrIndexFn& IndexAfterMappedElementsFn();
85 
86   // Returned function always returns an error.
87   static const UnmappedPtrIndexFn& InvalidIndexFn();
88 
89   // Sorts an unordered container of pointers according to the order of an
90   // ordered container of pointers. Sorting is stable. Works with POD pointers,
91   // const POD pointers, and unique_ptrs. If an error is returned,
92   // unordered_container is not modified. Returns an error status if:
93   // - unmapped_index() returns an invalid index
94   // - An internal error occurs. (This should theoretically not happen.)
95   template <typename OrderedTy, typename UnorderedTy>
96   static Status Sort(const MapPtrFn& map_ptr,
97                      const UnmappedPtrIndexFn& unmapped_index,
98                      const OrderedTy& ordered_container,
99                      UnorderedTy& unordered_container);
100 
101  private:
102   // A class for sorting the indices of the unordered_container.
103   class SortedIndices {
104    public:
105     // max_partial_order_exclusive is 1 greater than the maximum partial order
106     // value allowed to be sent to AddMappedElement().
SortedIndices(size_t max_partial_order_exclusive,size_t unordered_container_size)107     SortedIndices(size_t max_partial_order_exclusive,
108                   size_t unordered_container_size)
109         : max_partial_order_exclusive_(max_partial_order_exclusive),
110           unordered_container_size_(unordered_container_size),
111           mapped_element_indices_by_partial_order_(
112               max_partial_order_exclusive) {}
113 
114     // Specify the partial ordering value of a mapped element from the
115     // unordered container. The partial ordering is amongst other mapped
116     // elements.
117     Status AddMappedElement(size_t unordered_container_index,
118                             size_t partial_order);
119 
120     // Specify the index (amongst mapped elements), where an unmapped element
121     // should be inserted. The unmapped element is inserted just after the
122     // mapped element with index target_index_amongst_mapped_elements.
123     void AddUnmappedElement(size_t unordered_container_index,
124                             size_t target_index_amongst_mapped_elements);
125 
126     std::string ToString() const;
127 
128     // The result maps each element in the unordered_container to the target
129     // index that it will occupy in the sorted result.
130     StatusOr<std::vector<size_t>> Flatten() const;
131 
132    private:
133     SortedIndices() = delete;
134 
135     size_t max_partial_order_exclusive_;
136     size_t unordered_container_size_;
137     std::vector<std::vector<size_t>> mapped_element_indices_by_partial_order_;
138     absl::flat_hash_map<size_t, std::vector<size_t>>
139         target_index_to_unmapped_element_index_;
140   };
141 
IndexBeforeMappedElements()142   static size_t IndexBeforeMappedElements() {
143     return std::numeric_limits<size_t>::max() - 2;
144   }
145 
IndexAfterMappedElements()146   static size_t IndexAfterMappedElements() {
147     return std::numeric_limits<size_t>::max() - 1;
148   }
149 
InvalidIndex()150   static size_t InvalidIndex() { return std::numeric_limits<size_t>::max(); }
151 
152   // Returns a mapping in which the element at index i indicates the target
153   // index that unordered_container[i] should occupy in the sorted result.
154   template <typename OrderedTy, typename UnorderedTy>
155   static StatusOr<std::vector<size_t>> ComputeNewIndices(
156       const MapPtrFn& map_ptr, const UnmappedPtrIndexFn& unmapped_index,
157       const OrderedTy& ordered_container,
158       const UnorderedTy& unordered_container);
159 
160   // Reorders unordered_container according to the indices in new_indices. See
161   // ComputeNewIndices() for how to interpret new_indices.
162   template <typename UnorderedTy>
163   static void Reorder(std::vector<size_t> new_indices,
164                       UnorderedTy& unordered_container);
165 };
166 
167 ///// Template implementation below /////
168 
169 namespace mapped_ptr_container_sorter_internal {
170 
171 template <typename I, typename O>
172 struct PtrGetter {
173   // Extracts a pointer of type O from i.
174   static O Get(I i);
175 };
176 
177 template <typename T>
178 struct PtrGetter<T* const&, const T*> {
179   static const T* Get(T* const& p) { return p; }
180 };
181 
182 template <typename T>
183 struct PtrGetter<T const* const&, const T*> {
184   static const T* Get(T const* const& p) { return p; }
185 };
186 
187 template <typename T>
188 struct PtrGetter<T*&, T*> {
189   static T* Get(T*& p) { return p; }
190 };
191 
192 template <typename T>
193 struct PtrGetter<const std::unique_ptr<T>&, const T*> {
194   static const T* Get(const std::unique_ptr<T>& p) { return p.get(); }
195 };
196 
197 template <typename T>
198 struct PtrGetter<std::unique_ptr<T>&, T*> {
199   static T* Get(std::unique_ptr<T>& p) { return p.get(); }
200 };
201 
202 }  // namespace mapped_ptr_container_sorter_internal
203 
204 template <typename PointedToTy>
205 const typename MappedPtrContainerSorter<PointedToTy>::UnmappedPtrIndexFn&
206 MappedPtrContainerSorter<PointedToTy>::IndexBeforeMappedElementsFn() {
207   static const UnmappedPtrIndexFn* fn = new UnmappedPtrIndexFn(
208       [](const PointedToTy*) { return IndexBeforeMappedElements(); });
209   return *fn;
210 }
211 
212 template <typename PointedToTy>
213 const typename MappedPtrContainerSorter<PointedToTy>::UnmappedPtrIndexFn&
214 MappedPtrContainerSorter<PointedToTy>::IndexAfterMappedElementsFn() {
215   static const UnmappedPtrIndexFn* fn = new UnmappedPtrIndexFn(
216       [](const PointedToTy*) { return IndexAfterMappedElements(); });
217   return *fn;
218 }
219 
220 template <typename PointedToTy>
221 const typename MappedPtrContainerSorter<PointedToTy>::UnmappedPtrIndexFn&
222 MappedPtrContainerSorter<PointedToTy>::InvalidIndexFn() {
223   static const UnmappedPtrIndexFn* fn =
224       new UnmappedPtrIndexFn([](const PointedToTy*) { return InvalidIndex(); });
225   return *fn;
226 }
227 
228 template <typename PointedToTy>
229 Status MappedPtrContainerSorter<PointedToTy>::SortedIndices::AddMappedElement(
230     size_t unordered_container_index, size_t partial_order) {
231   if (partial_order >= mapped_element_indices_by_partial_order_.size()) {
232     return InternalErrorStrCat(
233         "invalid partial order: ", partial_order, " v max(",
234         mapped_element_indices_by_partial_order_.size(), ")");
235   }
236 
237   mapped_element_indices_by_partial_order_[partial_order].push_back(
238       unordered_container_index);
239   return Status::OK();
240 }
241 
242 template <typename PointedToTy>
243 void MappedPtrContainerSorter<PointedToTy>::SortedIndices::AddUnmappedElement(
244     size_t unordered_container_index,
245     size_t target_index_amongst_mapped_elements) {
246   target_index_to_unmapped_element_index_[target_index_amongst_mapped_elements]
247       .push_back(unordered_container_index);
248 }
249 
250 template <typename PointedToTy>
251 std::string MappedPtrContainerSorter<PointedToTy>::SortedIndices::ToString()
252     const {
253   std::vector<std::string> mapped_element_strs;
254   mapped_element_strs.reserve(mapped_element_indices_by_partial_order_.size());
255   for (const auto& indices : mapped_element_indices_by_partial_order_) {
256     mapped_element_strs.push_back(
257         absl::StrCat("[", absl::StrJoin(indices, ", "), "]"));
258   }
259   std::vector<std::string> unmapped_element_strs;
260   unmapped_element_strs.reserve(target_index_to_unmapped_element_index_.size());
261   for (const auto& kv : target_index_to_unmapped_element_index_) {
262     std::string key = absl::StrCat(kv.first);
263     if (kv.first == IndexBeforeMappedElements()) {
264       key = "before_mapped";
265     }
266     if (kv.first == IndexAfterMappedElements()) {
267       key = "after_mapped";
268     }
269     if (kv.first == InvalidIndex()) {
270       key = "invalid";
271     }
272     unmapped_element_strs.push_back(
273         absl::StrCat(key, ": [", absl::StrJoin(kv.second, ", "), "]"));
274   }
275 
276   return absl::StrCat(
277       "max_partial_order_exclusive_: ", max_partial_order_exclusive_, "\n",
278       "unordered_container_size_: ", unordered_container_size_, "\n",
279       "mapped_element_indices_by_partial_order_: [",
280       absl::StrJoin(mapped_element_strs, ", "), "]\n",
281       "target_index_to_unmapped_element_index_: {",
282       absl::StrJoin(unmapped_element_strs, ", "), "}\n");
283 }
284 
285 template <typename PointedToTy>
286 StatusOr<std::vector<size_t>>
287 MappedPtrContainerSorter<PointedToTy>::SortedIndices::Flatten() const {
288   std::vector<size_t> result(unordered_container_size_, InvalidIndex());
289   size_t next_available_index = 0;
290   auto next_index_fn = [&]() -> StatusOr<size_t> {
291     if (next_available_index >= unordered_container_size_) {
292       return InternalErrorStrCat(
293           "invalid unordered_container index: ", next_available_index,
294           " v size(", unordered_container_size_, ")");
295     }
296     return next_available_index++;
297   };
298 
299   if (target_index_to_unmapped_element_index_.contains(
300           IndexBeforeMappedElements())) {
301     const auto& indices =
302         target_index_to_unmapped_element_index_.at(IndexBeforeMappedElements());
303     for (size_t index : indices) {
304       TF_ASSIGN_OR_RETURN(result[index], next_index_fn());
305     }
306   }
307   size_t num_inserted_mapped_elements = 0;
308   for (const auto& mapped_element_indices :
309        mapped_element_indices_by_partial_order_) {
310     for (size_t mapped_element_index : mapped_element_indices) {
311       TF_ASSIGN_OR_RETURN(result[mapped_element_index], next_index_fn());
312       ++num_inserted_mapped_elements;
313       if (target_index_to_unmapped_element_index_.contains(
314               num_inserted_mapped_elements - 1)) {
315         const auto& unmapped_element_indices =
316             target_index_to_unmapped_element_index_.at(
317                 num_inserted_mapped_elements - 1);
318         for (size_t unmapped_element_index : unmapped_element_indices) {
319           TF_ASSIGN_OR_RETURN(result[unmapped_element_index], next_index_fn());
320         }
321       }
322     }
323   }
324   if (target_index_to_unmapped_element_index_.contains(
325           IndexAfterMappedElements())) {
326     const auto& indices =
327         target_index_to_unmapped_element_index_.at(IndexAfterMappedElements());
328     for (size_t index : indices) {
329       TF_ASSIGN_OR_RETURN(result[index], next_index_fn());
330     }
331   }
332 
333   // Ensure that every element in unordered_container has a valid new index.
334   absl::flat_hash_set<size_t> used_indices;
335   for (size_t index : result) {
336     if (used_indices.contains(index)) {
337       return InternalErrorStrCat(
338           "2 elements in unordered_container are destined for the same "
339           "index: ",
340           index);
341     }
342     if (index >= unordered_container_size_) {
343       return InvalidArgumentStrCat("invalid unordered_container index: ", index,
344                                    " v size(", unordered_container_size_, ")");
345     }
346   }
347 
348   return result;
349 }
350 
351 template <typename PointedToTy>
352 template <typename OrderedTy, typename UnorderedTy>
353 StatusOr<std::vector<size_t>>
354 MappedPtrContainerSorter<PointedToTy>::ComputeNewIndices(
355     const MapPtrFn& map_ptr, const UnmappedPtrIndexFn& unmapped_index,
356     const OrderedTy& ordered_container,
357     const UnorderedTy& unordered_container) {
358   using UnorderedPtrGetter = mapped_ptr_container_sorter_internal::PtrGetter<
359       typename UnorderedTy::const_reference, const PointedToTy*>;
360   using OrderedPtrGetter = mapped_ptr_container_sorter_internal::PtrGetter<
361       typename OrderedTy::const_reference, const PointedToTy*>;
362 
363   if (unordered_container.size() >= IndexBeforeMappedElements()) {
364     return InvalidArgumentStrCat("Unordered container is too large to sort.");
365   }
366 
367   // Step 1: build a set of the ptrs in unordered_container
368   absl::flat_hash_set<const PointedToTy*> unordered_ptrs;
369   for (const auto& unordered_element : unordered_container) {
370     const PointedToTy* ptr = UnorderedPtrGetter::Get(unordered_element);
371     unordered_ptrs.insert(ptr);
372   }
373 
374   // Step 2: for mapped elements (in unordered_container), create a map from
375   // mapped ptr -> partial ordering
376   absl::flat_hash_map<const PointedToTy*, std::list<size_t>>
377       mapped_ptr_to_partial_order;
378   size_t next_partial_order_value = 0;
379   for (const auto& ordered_element : ordered_container) {
380     const PointedToTy* ordered_ptr = OrderedPtrGetter::Get(ordered_element);
381     const PointedToTy* unordered_ptr = map_ptr(ordered_ptr);
382     if (!unordered_ptr) {
383       // A corresponding unordered element does not exist.
384       continue;
385     }
386     if (!unordered_ptrs.contains(unordered_ptr)) {
387       // A pointer exists that maps to the ordered element, but it's not in our
388       // unordered_container.
389       continue;
390     }
391     mapped_ptr_to_partial_order[unordered_ptr].push_back(
392         next_partial_order_value);
393     ++next_partial_order_value;
394   }
395 
396   // Step 3: create sorted unordered element indices
397   SortedIndices result(next_partial_order_value, unordered_container.size());
398   for (size_t i = 0; i < unordered_container.size(); ++i) {
399     const PointedToTy* ptr = UnorderedPtrGetter::Get(unordered_container[i]);
400     if (!mapped_ptr_to_partial_order.contains(ptr)) {
401       // ptr is unmapped
402       result.AddUnmappedElement(i, unmapped_index(ptr));
403       continue;
404     }
405 
406     // ptr is mapped
407     //
408     // Potentially, several elements in ordered_container map to ptr.
409     // We assign ptr theindex corresponding to the next such ordered element.
410     auto& index_list = mapped_ptr_to_partial_order[ptr];
411     TF_RETURN_IF_ERROR(result.AddMappedElement(i, index_list.front()));
412     // Do not map more than one unordered element to the same index, unless we
413     // have no choice.
414     if (index_list.size() > 1) {
415       // We never remove the last ordered index, in case ptr appears in the
416       // unordered_container more times than the ordered container.
417       index_list.pop_front();
418     }
419   }
420 
421   VLOG(5) << "Pre flatten unordered_container result:\n" << result.ToString();
422   return result.Flatten();
423 }
424 
425 template <typename PointedToTy>
426 template <typename UnorderedTy>
427 void MappedPtrContainerSorter<PointedToTy>::Reorder(
428     std::vector<size_t> new_indices, UnorderedTy& unordered_container) {
429   size_t old_pos = 0;
430   while (old_pos < new_indices.size()) {
431     size_t new_pos = new_indices[old_pos];
432     if (old_pos == new_pos) {
433       ++old_pos;
434       continue;
435     }
436     std::swap(new_indices[old_pos], new_indices[new_pos]);
437     std::swap(unordered_container[old_pos], unordered_container[new_pos]);
438   }
439 }
440 
441 template <typename PointedToTy>
442 template <typename OrderedTy, typename UnorderedTy>
443 Status MappedPtrContainerSorter<PointedToTy>::Sort(
444     const MapPtrFn& map_ptr, const UnmappedPtrIndexFn& unmapped_index,
445     const OrderedTy& ordered_container, UnorderedTy& unordered_container) {
446   std::vector<size_t> indices;
447   TF_ASSIGN_OR_RETURN(
448       indices, ComputeNewIndices(map_ptr, unmapped_index, ordered_container,
449                                  unordered_container));
450   Reorder(std::move(indices), unordered_container);
451   return Status::OK();
452 }
453 
454 }  // namespace xla
455 
456 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_MAPPED_PTR_CONTAINER_SORTER_H_
457