1# Owner(s): ["module: functorch"] 2 3"""Adapted from https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/tests/test_ops.py. 4 5MIT License 6 7Copyright (c) 2018 Alex Rogozhnikov 8 9Permission is hereby granted, free of charge, to any person obtaining a copy 10of this software and associated documentation files (the "Software"), to deal 11in the Software without restriction, including without limitation the rights 12to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13copies of the Software, and to permit persons to whom the Software is 14furnished to do so, subject to the following conditions: 15 16The above copyright notice and this permission notice shall be included in all 17copies or substantial portions of the Software. 18 19THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25SOFTWARE. 26""" 27 28from typing import List, Tuple 29 30import numpy as np 31 32import torch 33from functorch.einops import rearrange 34from torch.testing._internal.common_utils import run_tests, TestCase 35 36 37identity_patterns: List[str] = [ 38 "...->...", 39 "a b c d e-> a b c d e", 40 "a b c d e ...-> ... a b c d e", 41 "a b c d e ...-> a ... b c d e", 42 "... a b c d e -> ... a b c d e", 43 "a ... e-> a ... e", 44 "a ... -> a ... ", 45 "a ... c d e -> a (...) c d e", 46] 47 48equivalent_rearrange_patterns: List[Tuple[str, str]] = [ 49 ("a b c d e -> (a b) c d e", "a b ... -> (a b) ... "), 50 ("a b c d e -> a b (c d) e", "... c d e -> ... (c d) e"), 51 ("a b c d e -> a b c d e", "... -> ... "), 52 ("a b c d e -> (a b c d e)", "... -> (...)"), 53 ("a b c d e -> b (c d e) a", "a b ... -> b (...) a"), 54 ("a b c d e -> b (a c d) e", "a b ... e -> b (a ...) e"), 55] 56 57 58class TestRearrange(TestCase): 59 def test_collapsed_ellipsis_errors_out(self) -> None: 60 x = torch.zeros([1, 1, 1, 1, 1]) 61 rearrange(x, "a b c d ... -> a b c ... d") 62 with self.assertRaises(ValueError): 63 rearrange(x, "a b c d (...) -> a b c ... d") 64 65 rearrange(x, "... -> (...)") 66 with self.assertRaises(ValueError): 67 rearrange(x, "(...) -> (...)") 68 69 def test_ellipsis_ops(self) -> None: 70 x = torch.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) 71 for pattern in identity_patterns: 72 torch.testing.assert_close(rearrange(x, pattern), x, msg=pattern) 73 74 for pattern1, pattern2 in equivalent_rearrange_patterns: 75 torch.testing.assert_close( 76 rearrange(x, pattern1), 77 rearrange(x, pattern2), 78 msg=f"{pattern1} vs {pattern2}", 79 ) 80 81 def test_rearrange_consistency(self) -> None: 82 shape = [1, 2, 3, 5, 7, 11] 83 x = torch.arange(int(np.prod(shape, dtype=int))).reshape(shape) 84 for pattern in [ 85 "a b c d e f -> a b c d e f", 86 "b a c d e f -> a b d e f c", 87 "a b c d e f -> f e d c b a", 88 "a b c d e f -> (f e) d (c b a)", 89 "a b c d e f -> (f e d c b a)", 90 ]: 91 result = rearrange(x, pattern) 92 self.assertEqual(len(np.setdiff1d(x, result)), 0) 93 self.assertIs(result.dtype, x.dtype) 94 95 result = rearrange(x, "a b c d e f -> a (b) (c d e) f") 96 torch.testing.assert_close(x.flatten(), result.flatten()) 97 98 result = rearrange(x, "a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11") 99 torch.testing.assert_close(x, result) 100 101 result1 = rearrange(x, "a b c d e f -> f e d c b a") 102 result2 = rearrange(x, "f e d c b a -> a b c d e f") 103 torch.testing.assert_close(result1, result2) 104 105 result = rearrange( 106 rearrange(x, "a b c d e f -> (f d) c (e b) a"), 107 "(f d) c (e b) a -> a b c d e f", 108 b=2, 109 d=5, 110 ) 111 torch.testing.assert_close(x, result) 112 113 sizes = dict(zip("abcdef", shape)) 114 temp = rearrange(x, "a b c d e f -> (f d) c (e b) a", **sizes) 115 result = rearrange(temp, "(f d) c (e b) a -> a b c d e f", **sizes) 116 torch.testing.assert_close(x, result) 117 118 x2 = torch.arange(2 * 3 * 4).reshape([2, 3, 4]) 119 result = rearrange(x2, "a b c -> b c a") 120 self.assertEqual(x2[1, 2, 3], result[2, 3, 1]) 121 self.assertEqual(x2[0, 1, 2], result[1, 2, 0]) 122 123 def test_rearrange_permutations(self) -> None: 124 # tests random permutation of axes against two independent numpy ways 125 for n_axes in range(1, 10): 126 input = torch.arange(2**n_axes).reshape([2] * n_axes) 127 permutation = np.random.permutation(n_axes) 128 left_expression = " ".join("i" + str(axis) for axis in range(n_axes)) 129 right_expression = " ".join("i" + str(axis) for axis in permutation) 130 expression = left_expression + " -> " + right_expression 131 result = rearrange(input, expression) 132 133 for pick in np.random.randint(0, 2, [10, n_axes]): 134 self.assertEqual(input[tuple(pick)], result[tuple(pick[permutation])]) 135 136 for n_axes in range(1, 10): 137 input = torch.arange(2**n_axes).reshape([2] * n_axes) 138 permutation = np.random.permutation(n_axes) 139 left_expression = " ".join("i" + str(axis) for axis in range(n_axes)[::-1]) 140 right_expression = " ".join("i" + str(axis) for axis in permutation[::-1]) 141 expression = left_expression + " -> " + right_expression 142 result = rearrange(input, expression) 143 self.assertEqual(result.shape, input.shape) 144 expected_result = torch.zeros_like(input) 145 for original_axis, result_axis in enumerate(permutation): 146 expected_result |= ((input >> original_axis) & 1) << result_axis 147 148 torch.testing.assert_close(result, expected_result) 149 150 def test_concatenations_and_stacking(self) -> None: 151 for n_arrays in [1, 2, 5]: 152 shapes: List[List[int]] = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6] 153 for shape in shapes: 154 arrays1 = [ 155 torch.arange(i, i + np.prod(shape, dtype=int)).reshape(shape) 156 for i in range(n_arrays) 157 ] 158 result0 = torch.stack(arrays1) 159 result1 = rearrange(arrays1, "...->...") 160 torch.testing.assert_close(result0, result1) 161 162 def test_unsqueeze(self) -> None: 163 x = torch.randn((2, 3, 4, 5)) 164 actual = rearrange(x, "b h w c -> b 1 h w 1 c") 165 expected = x.unsqueeze(1).unsqueeze(-2) 166 torch.testing.assert_close(actual, expected) 167 168 def test_squeeze(self) -> None: 169 x = torch.randn((2, 1, 3, 4, 1, 5)) 170 actual = rearrange(x, "b 1 h w 1 c -> b h w c") 171 expected = x.squeeze() 172 torch.testing.assert_close(actual, expected) 173 174 def test_0_dim_tensor(self) -> None: 175 x = expected = torch.tensor(1) 176 actual = rearrange(x, "->") 177 torch.testing.assert_close(actual, expected) 178 179 actual = rearrange(x, "... -> ...") 180 torch.testing.assert_close(actual, expected) 181 182 def test_dimension_mismatch_no_ellipsis(self) -> None: 183 x = torch.randn((1, 2, 3)) 184 with self.assertRaises(ValueError): 185 rearrange(x, "a b -> b a") 186 187 with self.assertRaises(ValueError): 188 rearrange(x, "a b c d -> c d b a") 189 190 def test_dimension_mismatch_with_ellipsis(self) -> None: 191 x = torch.tensor(1) 192 with self.assertRaises(ValueError): 193 rearrange(x, "a ... -> ... a") 194 195 196if __name__ == "__main__": 197 run_tests() 198