xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/permutation_util.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/ArrayRef.h>
4 #include <c10/util/Exception.h>
5 #include <c10/util/irange.h>
6 
7 #include <vector>
8 
9 namespace torch {
10 namespace lazy {
11 
12 TORCH_API std::vector<int64_t> InversePermutation(
13     c10::ArrayRef<int64_t> input_permutation);
14 
15 TORCH_API bool IsPermutation(c10::ArrayRef<int64_t> permutation);
16 
17 // Gathers the input using the order specified by the permutation. For each i,
18 // output[i] = dimensions[permutation[i]]. The given permutation must be the
19 // same size as the input.
20 template <typename Container>
PermuteDimensions(c10::ArrayRef<int64_t> permutation,const Container & dimensions)21 std::vector<typename Container::value_type> PermuteDimensions(
22     c10::ArrayRef<int64_t> permutation,
23     const Container& dimensions) {
24   using T = typename Container::value_type;
25   TORCH_CHECK(
26       dimensions.size() == permutation.size(),
27       "Invalid permutation specified. dimensions.size() != permutation.size()  (",
28       dimensions.size(),
29       " vs. ",
30       permutation.size(),
31       ")");
32   TORCH_CHECK(
33       IsPermutation(permutation),
34       "Invalid permutation specified. Permutation is not permutation");
35   std::vector<T> output(dimensions.size());
36   for (const auto i : c10::irange(permutation.size())) {
37     output[i] = dimensions[permutation[i]];
38   }
39   return output;
40 }
41 
42 } // namespace lazy
43 } // namespace torch
44