1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Below, we specify an example usage, in which clone is sorted according to 17 // original, using map_fn to map from pointers in original to pointers in clone. 18 // 19 // std::vector<std::unique_ptr<HloInstruction*>> original = ...; 20 // std::vector<std::unique_ptr<HloInstruction*>> clone = ...; 21 // HloCloneContext* ctx = ...; 22 // using Sorter = MappedPtrContainerSorter<HloInstruction>; 23 // Sorter::MappedPtrFn map_fn = [ctx](const HloInstruction* i) { 24 // return ctx->FindInstruction(i); 25 // }; 26 // 27 // auto status = Sorter::Sort(map_fn, Sorter::IndexAfterMappedElementsFn(), 28 // original, clone); 29 30 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAPPED_PTR_CONTAINER_SORTER_H_ 31 #define TENSORFLOW_COMPILER_XLA_SERVICE_MAPPED_PTR_CONTAINER_SORTER_H_ 32 33 #include <array> 34 #include <cstddef> 35 #include <functional> 36 #include <limits> 37 #include <list> 38 #include <memory> 39 #include <string> 40 #include <utility> 41 #include <vector> 42 43 #include "absl/container/flat_hash_map.h" 44 #include "absl/container/flat_hash_set.h" 45 #include "absl/strings/str_cat.h" 46 #include "absl/strings/str_join.h" 47 #include "tensorflow/compiler/xla/status.h" 48 #include "tensorflow/compiler/xla/statusor.h" 49 #include "tensorflow/compiler/xla/util.h" 50 #include "tensorflow/core/platform/errors.h" 51 #include "tensorflow/core/platform/logging.h" 52 #include "tensorflow/core/platform/statusor.h" 53 54 namespace xla { 55 56 // A class for sorting an unordered container of pointers according to the sort 57 // order of an ordered container of pointers. Sorting is stable. 58 // 59 // Terminology: 60 // - unmapped element: An element from the unordered container that does not 61 // have a corresponding element in the ordered container. 62 template <typename PointedToTy> 63 class MappedPtrContainerSorter { 64 public: 65 // A function to map elements from an ordered container to elements in an 66 // unordered container. Not every element in ordered_container need map to an 67 // element in unordered_container and vice versa. 68 using MapPtrFn = std::function<const PointedToTy*(const PointedToTy*)>; 69 70 // A function that maps unmapped elements (from an unordered container) to an 71 // index in the final sorted result. The returned index indicates that the 72 // unmapped element should be placed just after the mapped element at that 73 // index, in the result without unmapped elements. See 74 // IndexBeforeMappedElementsFn() and IndexAfterMappedElementsFn() for how to 75 // indicate that an unmapped element should be placed before or after all 76 // mapped elements, respectively. Unmapped elements destined for the same 77 // index will retain their order from the unordered container. 78 using UnmappedPtrIndexFn = std::function<size_t(const PointedToTy*)>; 79 80 // Functions that return an UnmappedElementIndexFn that indicates that 81 // ummapped elements (from an unordered container) should be placed before or 82 // after all mapped elements, respectively. 83 static const UnmappedPtrIndexFn& IndexBeforeMappedElementsFn(); 84 static const UnmappedPtrIndexFn& IndexAfterMappedElementsFn(); 85 86 // Returned function always returns an error. 87 static const UnmappedPtrIndexFn& InvalidIndexFn(); 88 89 // Sorts an unordered container of pointers according to the order of an 90 // ordered container of pointers. Sorting is stable. Works with POD pointers, 91 // const POD pointers, and unique_ptrs. If an error is returned, 92 // unordered_container is not modified. Returns an error status if: 93 // - unmapped_index() returns an invalid index 94 // - An internal error occurs. (This should theoretically not happen.) 95 template <typename OrderedTy, typename UnorderedTy> 96 static Status Sort(const MapPtrFn& map_ptr, 97 const UnmappedPtrIndexFn& unmapped_index, 98 const OrderedTy& ordered_container, 99 UnorderedTy& unordered_container); 100 101 private: 102 // A class for sorting the indices of the unordered_container. 103 class SortedIndices { 104 public: 105 // max_partial_order_exclusive is 1 greater than the maximum partial order 106 // value allowed to be sent to AddMappedElement(). SortedIndices(size_t max_partial_order_exclusive,size_t unordered_container_size)107 SortedIndices(size_t max_partial_order_exclusive, 108 size_t unordered_container_size) 109 : max_partial_order_exclusive_(max_partial_order_exclusive), 110 unordered_container_size_(unordered_container_size), 111 mapped_element_indices_by_partial_order_( 112 max_partial_order_exclusive) {} 113 114 // Specify the partial ordering value of a mapped element from the 115 // unordered container. The partial ordering is amongst other mapped 116 // elements. 117 Status AddMappedElement(size_t unordered_container_index, 118 size_t partial_order); 119 120 // Specify the index (amongst mapped elements), where an unmapped element 121 // should be inserted. The unmapped element is inserted just after the 122 // mapped element with index target_index_amongst_mapped_elements. 123 void AddUnmappedElement(size_t unordered_container_index, 124 size_t target_index_amongst_mapped_elements); 125 126 std::string ToString() const; 127 128 // The result maps each element in the unordered_container to the target 129 // index that it will occupy in the sorted result. 130 StatusOr<std::vector<size_t>> Flatten() const; 131 132 private: 133 SortedIndices() = delete; 134 135 size_t max_partial_order_exclusive_; 136 size_t unordered_container_size_; 137 std::vector<std::vector<size_t>> mapped_element_indices_by_partial_order_; 138 absl::flat_hash_map<size_t, std::vector<size_t>> 139 target_index_to_unmapped_element_index_; 140 }; 141 IndexBeforeMappedElements()142 static size_t IndexBeforeMappedElements() { 143 return std::numeric_limits<size_t>::max() - 2; 144 } 145 IndexAfterMappedElements()146 static size_t IndexAfterMappedElements() { 147 return std::numeric_limits<size_t>::max() - 1; 148 } 149 InvalidIndex()150 static size_t InvalidIndex() { return std::numeric_limits<size_t>::max(); } 151 152 // Returns a mapping in which the element at index i indicates the target 153 // index that unordered_container[i] should occupy in the sorted result. 154 template <typename OrderedTy, typename UnorderedTy> 155 static StatusOr<std::vector<size_t>> ComputeNewIndices( 156 const MapPtrFn& map_ptr, const UnmappedPtrIndexFn& unmapped_index, 157 const OrderedTy& ordered_container, 158 const UnorderedTy& unordered_container); 159 160 // Reorders unordered_container according to the indices in new_indices. See 161 // ComputeNewIndices() for how to interpret new_indices. 162 template <typename UnorderedTy> 163 static void Reorder(std::vector<size_t> new_indices, 164 UnorderedTy& unordered_container); 165 }; 166 167 ///// Template implementation below ///// 168 169 namespace mapped_ptr_container_sorter_internal { 170 171 template <typename I, typename O> 172 struct PtrGetter { 173 // Extracts a pointer of type O from i. 174 static O Get(I i); 175 }; 176 177 template <typename T> 178 struct PtrGetter<T* const&, const T*> { 179 static const T* Get(T* const& p) { return p; } 180 }; 181 182 template <typename T> 183 struct PtrGetter<T const* const&, const T*> { 184 static const T* Get(T const* const& p) { return p; } 185 }; 186 187 template <typename T> 188 struct PtrGetter<T*&, T*> { 189 static T* Get(T*& p) { return p; } 190 }; 191 192 template <typename T> 193 struct PtrGetter<const std::unique_ptr<T>&, const T*> { 194 static const T* Get(const std::unique_ptr<T>& p) { return p.get(); } 195 }; 196 197 template <typename T> 198 struct PtrGetter<std::unique_ptr<T>&, T*> { 199 static T* Get(std::unique_ptr<T>& p) { return p.get(); } 200 }; 201 202 } // namespace mapped_ptr_container_sorter_internal 203 204 template <typename PointedToTy> 205 const typename MappedPtrContainerSorter<PointedToTy>::UnmappedPtrIndexFn& 206 MappedPtrContainerSorter<PointedToTy>::IndexBeforeMappedElementsFn() { 207 static const UnmappedPtrIndexFn* fn = new UnmappedPtrIndexFn( 208 [](const PointedToTy*) { return IndexBeforeMappedElements(); }); 209 return *fn; 210 } 211 212 template <typename PointedToTy> 213 const typename MappedPtrContainerSorter<PointedToTy>::UnmappedPtrIndexFn& 214 MappedPtrContainerSorter<PointedToTy>::IndexAfterMappedElementsFn() { 215 static const UnmappedPtrIndexFn* fn = new UnmappedPtrIndexFn( 216 [](const PointedToTy*) { return IndexAfterMappedElements(); }); 217 return *fn; 218 } 219 220 template <typename PointedToTy> 221 const typename MappedPtrContainerSorter<PointedToTy>::UnmappedPtrIndexFn& 222 MappedPtrContainerSorter<PointedToTy>::InvalidIndexFn() { 223 static const UnmappedPtrIndexFn* fn = 224 new UnmappedPtrIndexFn([](const PointedToTy*) { return InvalidIndex(); }); 225 return *fn; 226 } 227 228 template <typename PointedToTy> 229 Status MappedPtrContainerSorter<PointedToTy>::SortedIndices::AddMappedElement( 230 size_t unordered_container_index, size_t partial_order) { 231 if (partial_order >= mapped_element_indices_by_partial_order_.size()) { 232 return InternalErrorStrCat( 233 "invalid partial order: ", partial_order, " v max(", 234 mapped_element_indices_by_partial_order_.size(), ")"); 235 } 236 237 mapped_element_indices_by_partial_order_[partial_order].push_back( 238 unordered_container_index); 239 return Status::OK(); 240 } 241 242 template <typename PointedToTy> 243 void MappedPtrContainerSorter<PointedToTy>::SortedIndices::AddUnmappedElement( 244 size_t unordered_container_index, 245 size_t target_index_amongst_mapped_elements) { 246 target_index_to_unmapped_element_index_[target_index_amongst_mapped_elements] 247 .push_back(unordered_container_index); 248 } 249 250 template <typename PointedToTy> 251 std::string MappedPtrContainerSorter<PointedToTy>::SortedIndices::ToString() 252 const { 253 std::vector<std::string> mapped_element_strs; 254 mapped_element_strs.reserve(mapped_element_indices_by_partial_order_.size()); 255 for (const auto& indices : mapped_element_indices_by_partial_order_) { 256 mapped_element_strs.push_back( 257 absl::StrCat("[", absl::StrJoin(indices, ", "), "]")); 258 } 259 std::vector<std::string> unmapped_element_strs; 260 unmapped_element_strs.reserve(target_index_to_unmapped_element_index_.size()); 261 for (const auto& kv : target_index_to_unmapped_element_index_) { 262 std::string key = absl::StrCat(kv.first); 263 if (kv.first == IndexBeforeMappedElements()) { 264 key = "before_mapped"; 265 } 266 if (kv.first == IndexAfterMappedElements()) { 267 key = "after_mapped"; 268 } 269 if (kv.first == InvalidIndex()) { 270 key = "invalid"; 271 } 272 unmapped_element_strs.push_back( 273 absl::StrCat(key, ": [", absl::StrJoin(kv.second, ", "), "]")); 274 } 275 276 return absl::StrCat( 277 "max_partial_order_exclusive_: ", max_partial_order_exclusive_, "\n", 278 "unordered_container_size_: ", unordered_container_size_, "\n", 279 "mapped_element_indices_by_partial_order_: [", 280 absl::StrJoin(mapped_element_strs, ", "), "]\n", 281 "target_index_to_unmapped_element_index_: {", 282 absl::StrJoin(unmapped_element_strs, ", "), "}\n"); 283 } 284 285 template <typename PointedToTy> 286 StatusOr<std::vector<size_t>> 287 MappedPtrContainerSorter<PointedToTy>::SortedIndices::Flatten() const { 288 std::vector<size_t> result(unordered_container_size_, InvalidIndex()); 289 size_t next_available_index = 0; 290 auto next_index_fn = [&]() -> StatusOr<size_t> { 291 if (next_available_index >= unordered_container_size_) { 292 return InternalErrorStrCat( 293 "invalid unordered_container index: ", next_available_index, 294 " v size(", unordered_container_size_, ")"); 295 } 296 return next_available_index++; 297 }; 298 299 if (target_index_to_unmapped_element_index_.contains( 300 IndexBeforeMappedElements())) { 301 const auto& indices = 302 target_index_to_unmapped_element_index_.at(IndexBeforeMappedElements()); 303 for (size_t index : indices) { 304 TF_ASSIGN_OR_RETURN(result[index], next_index_fn()); 305 } 306 } 307 size_t num_inserted_mapped_elements = 0; 308 for (const auto& mapped_element_indices : 309 mapped_element_indices_by_partial_order_) { 310 for (size_t mapped_element_index : mapped_element_indices) { 311 TF_ASSIGN_OR_RETURN(result[mapped_element_index], next_index_fn()); 312 ++num_inserted_mapped_elements; 313 if (target_index_to_unmapped_element_index_.contains( 314 num_inserted_mapped_elements - 1)) { 315 const auto& unmapped_element_indices = 316 target_index_to_unmapped_element_index_.at( 317 num_inserted_mapped_elements - 1); 318 for (size_t unmapped_element_index : unmapped_element_indices) { 319 TF_ASSIGN_OR_RETURN(result[unmapped_element_index], next_index_fn()); 320 } 321 } 322 } 323 } 324 if (target_index_to_unmapped_element_index_.contains( 325 IndexAfterMappedElements())) { 326 const auto& indices = 327 target_index_to_unmapped_element_index_.at(IndexAfterMappedElements()); 328 for (size_t index : indices) { 329 TF_ASSIGN_OR_RETURN(result[index], next_index_fn()); 330 } 331 } 332 333 // Ensure that every element in unordered_container has a valid new index. 334 absl::flat_hash_set<size_t> used_indices; 335 for (size_t index : result) { 336 if (used_indices.contains(index)) { 337 return InternalErrorStrCat( 338 "2 elements in unordered_container are destined for the same " 339 "index: ", 340 index); 341 } 342 if (index >= unordered_container_size_) { 343 return InvalidArgumentStrCat("invalid unordered_container index: ", index, 344 " v size(", unordered_container_size_, ")"); 345 } 346 } 347 348 return result; 349 } 350 351 template <typename PointedToTy> 352 template <typename OrderedTy, typename UnorderedTy> 353 StatusOr<std::vector<size_t>> 354 MappedPtrContainerSorter<PointedToTy>::ComputeNewIndices( 355 const MapPtrFn& map_ptr, const UnmappedPtrIndexFn& unmapped_index, 356 const OrderedTy& ordered_container, 357 const UnorderedTy& unordered_container) { 358 using UnorderedPtrGetter = mapped_ptr_container_sorter_internal::PtrGetter< 359 typename UnorderedTy::const_reference, const PointedToTy*>; 360 using OrderedPtrGetter = mapped_ptr_container_sorter_internal::PtrGetter< 361 typename OrderedTy::const_reference, const PointedToTy*>; 362 363 if (unordered_container.size() >= IndexBeforeMappedElements()) { 364 return InvalidArgumentStrCat("Unordered container is too large to sort."); 365 } 366 367 // Step 1: build a set of the ptrs in unordered_container 368 absl::flat_hash_set<const PointedToTy*> unordered_ptrs; 369 for (const auto& unordered_element : unordered_container) { 370 const PointedToTy* ptr = UnorderedPtrGetter::Get(unordered_element); 371 unordered_ptrs.insert(ptr); 372 } 373 374 // Step 2: for mapped elements (in unordered_container), create a map from 375 // mapped ptr -> partial ordering 376 absl::flat_hash_map<const PointedToTy*, std::list<size_t>> 377 mapped_ptr_to_partial_order; 378 size_t next_partial_order_value = 0; 379 for (const auto& ordered_element : ordered_container) { 380 const PointedToTy* ordered_ptr = OrderedPtrGetter::Get(ordered_element); 381 const PointedToTy* unordered_ptr = map_ptr(ordered_ptr); 382 if (!unordered_ptr) { 383 // A corresponding unordered element does not exist. 384 continue; 385 } 386 if (!unordered_ptrs.contains(unordered_ptr)) { 387 // A pointer exists that maps to the ordered element, but it's not in our 388 // unordered_container. 389 continue; 390 } 391 mapped_ptr_to_partial_order[unordered_ptr].push_back( 392 next_partial_order_value); 393 ++next_partial_order_value; 394 } 395 396 // Step 3: create sorted unordered element indices 397 SortedIndices result(next_partial_order_value, unordered_container.size()); 398 for (size_t i = 0; i < unordered_container.size(); ++i) { 399 const PointedToTy* ptr = UnorderedPtrGetter::Get(unordered_container[i]); 400 if (!mapped_ptr_to_partial_order.contains(ptr)) { 401 // ptr is unmapped 402 result.AddUnmappedElement(i, unmapped_index(ptr)); 403 continue; 404 } 405 406 // ptr is mapped 407 // 408 // Potentially, several elements in ordered_container map to ptr. 409 // We assign ptr theindex corresponding to the next such ordered element. 410 auto& index_list = mapped_ptr_to_partial_order[ptr]; 411 TF_RETURN_IF_ERROR(result.AddMappedElement(i, index_list.front())); 412 // Do not map more than one unordered element to the same index, unless we 413 // have no choice. 414 if (index_list.size() > 1) { 415 // We never remove the last ordered index, in case ptr appears in the 416 // unordered_container more times than the ordered container. 417 index_list.pop_front(); 418 } 419 } 420 421 VLOG(5) << "Pre flatten unordered_container result:\n" << result.ToString(); 422 return result.Flatten(); 423 } 424 425 template <typename PointedToTy> 426 template <typename UnorderedTy> 427 void MappedPtrContainerSorter<PointedToTy>::Reorder( 428 std::vector<size_t> new_indices, UnorderedTy& unordered_container) { 429 size_t old_pos = 0; 430 while (old_pos < new_indices.size()) { 431 size_t new_pos = new_indices[old_pos]; 432 if (old_pos == new_pos) { 433 ++old_pos; 434 continue; 435 } 436 std::swap(new_indices[old_pos], new_indices[new_pos]); 437 std::swap(unordered_container[old_pos], unordered_container[new_pos]); 438 } 439 } 440 441 template <typename PointedToTy> 442 template <typename OrderedTy, typename UnorderedTy> 443 Status MappedPtrContainerSorter<PointedToTy>::Sort( 444 const MapPtrFn& map_ptr, const UnmappedPtrIndexFn& unmapped_index, 445 const OrderedTy& ordered_container, UnorderedTy& unordered_container) { 446 std::vector<size_t> indices; 447 TF_ASSIGN_OR_RETURN( 448 indices, ComputeNewIndices(map_ptr, unmapped_index, ordered_container, 449 unordered_container)); 450 Reorder(std::move(indices), unordered_container); 451 return Status::OK(); 452 } 453 454 } // namespace xla 455 456 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAPPED_PTR_CONTAINER_SORTER_H_ 457