1 #pragma once 2 3 #include <c10/util/ArrayRef.h> 4 #include <c10/util/Exception.h> 5 #include <c10/util/irange.h> 6 7 #include <vector> 8 9 namespace torch { 10 namespace lazy { 11 12 TORCH_API std::vector<int64_t> InversePermutation( 13 c10::ArrayRef<int64_t> input_permutation); 14 15 TORCH_API bool IsPermutation(c10::ArrayRef<int64_t> permutation); 16 17 // Gathers the input using the order specified by the permutation. For each i, 18 // output[i] = dimensions[permutation[i]]. The given permutation must be the 19 // same size as the input. 20 template <typename Container> PermuteDimensions(c10::ArrayRef<int64_t> permutation,const Container & dimensions)21std::vector<typename Container::value_type> PermuteDimensions( 22 c10::ArrayRef<int64_t> permutation, 23 const Container& dimensions) { 24 using T = typename Container::value_type; 25 TORCH_CHECK( 26 dimensions.size() == permutation.size(), 27 "Invalid permutation specified. dimensions.size() != permutation.size() (", 28 dimensions.size(), 29 " vs. ", 30 permutation.size(), 31 ")"); 32 TORCH_CHECK( 33 IsPermutation(permutation), 34 "Invalid permutation specified. Permutation is not permutation"); 35 std::vector<T> output(dimensions.size()); 36 for (const auto i : c10::irange(permutation.size())) { 37 output[i] = dimensions[permutation[i]]; 38 } 39 return output; 40 } 41 42 } // namespace lazy 43 } // namespace torch 44