xref: /aosp_15_r20/external/pytorch/test/functorch/test_rearrange.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: functorch"]
2
3"""Adapted from https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/tests/test_ops.py.
4
5MIT License
6
7Copyright (c) 2018 Alex Rogozhnikov
8
9Permission is hereby granted, free of charge, to any person obtaining a copy
10of this software and associated documentation files (the "Software"), to deal
11in the Software without restriction, including without limitation the rights
12to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13copies of the Software, and to permit persons to whom the Software is
14furnished to do so, subject to the following conditions:
15
16The above copyright notice and this permission notice shall be included in all
17copies or substantial portions of the Software.
18
19THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25SOFTWARE.
26"""
27
28from typing import List, Tuple
29
30import numpy as np
31
32import torch
33from functorch.einops import rearrange
34from torch.testing._internal.common_utils import run_tests, TestCase
35
36
# Patterns expected to return their input unchanged. Spacing around "->" and
# trailing whitespace vary between entries, presumably to exercise the pattern
# parser's tolerance. Comments below assume the 5-D input used by the tests.
identity_patterns: List[str] = [
    "...->...",  # ellipsis alone, no spaces at all
    "a b c d e-> a b c d e",  # every axis named explicitly
    "a b c d e ...-> ... a b c d e",  # ellipsis matches no axes, moved to front
    "a b c d e ...-> a ... b c d e",  # empty ellipsis moved between axes
    "... a b c d e -> ... a b c d e",  # leading ellipsis, also empty
    "a ... e-> a ... e",  # ellipsis spans the three middle axes
    "a ... -> a ... ",  # note the trailing space after the output side
    "a ... c d e -> a (...) c d e",  # one-axis ellipsis wrapped in a group
]
47
# Pairs of patterns that must give identical results. In each pair the first
# pattern names every axis, and the second replaces some of them with "...".
equivalent_rearrange_patterns: List[Tuple[str, str]] = [
    (
        "a b c d e -> (a b) c d e",
        "a b ... -> (a b) ... ",
    ),
    (
        "a b c d e -> a b (c d) e",
        "... c d e -> ... (c d) e",
    ),
    (
        "a b c d e -> a b c d e",
        "... -> ... ",
    ),
    (
        "a b c d e -> (a b c d e)",
        "... ->  (...)",
    ),
    (
        "a b c d e -> b (c d e) a",
        "a b ... -> b (...) a",
    ),
    (
        "a b c d e -> b (a c d) e",
        "a b ... e -> b (a ...) e",
    ),
]
56
57
class TestRearrange(TestCase):
    """Tests for ``functorch.einops.rearrange``, adapted from einops' test suite."""

    def test_collapsed_ellipsis_errors_out(self) -> None:
        """A parenthesized ``...`` raises ValueError on the input side only."""
        x = torch.zeros([1, 1, 1, 1, 1])
        rearrange(x, "a b c d ... ->  a b c ... d")
        with self.assertRaises(ValueError):
            rearrange(x, "a b c d (...) ->  a b c ... d")

        # Grouping the ellipsis on the output side is allowed.
        rearrange(x, "... ->  (...)")
        with self.assertRaises(ValueError):
            rearrange(x, "(...) -> (...)")

    def test_ellipsis_ops(self) -> None:
        """Identity patterns return the input, and paired patterns agree."""
        x = torch.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
        for pattern in identity_patterns:
            torch.testing.assert_close(rearrange(x, pattern), x, msg=pattern)

        for pattern1, pattern2 in equivalent_rearrange_patterns:
            torch.testing.assert_close(
                rearrange(x, pattern1),
                rearrange(x, pattern2),
                msg=f"{pattern1} vs {pattern2}",
            )

    def test_rearrange_consistency(self) -> None:
        """Permuting/grouping axes keeps elements and dtype, and round-trips."""
        shape = [1, 2, 3, 5, 7, 11]
        x = torch.arange(int(np.prod(shape, dtype=int))).reshape(shape)
        for pattern in [
            "a b c d e f -> a b c d e f",
            "b a c d e f -> a b d e f c",
            "a b c d e f -> f e d c b a",
            "a b c d e f -> (f e) d (c b a)",
            "a b c d e f -> (f e d c b a)",
        ]:
            result = rearrange(x, pattern)
            # x holds distinct values (arange), so an empty set difference
            # means every element of x is still present in the result.
            self.assertEqual(len(np.setdiff1d(x, result)), 0)
            self.assertIs(result.dtype, x.dtype)

        # Grouping adjacent axes without reordering them leaves the flattened
        # data unchanged.
        result = rearrange(x, "a b c d e f -> a (b) (c d e) f")
        torch.testing.assert_close(x.flatten(), result.flatten())

        # Multi-character axis names behave like single-letter ones.
        result = rearrange(x, "a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11")
        torch.testing.assert_close(x, result)

        # Both patterns reverse the axis order, whatever the names are.
        result1 = rearrange(x, "a b c d e f -> f e d c b a")
        result2 = rearrange(x, "f e d c b a -> a b c d e f")
        torch.testing.assert_close(result1, result2)

        # Compose then decompose: sizes for b and d suffice to split the
        # groups (f d) and (e b) back apart.
        result = rearrange(
            rearrange(x, "a b c d e f -> (f d) c (e b) a"),
            "(f d) c (e b) a -> a b c d e f",
            b=2,
            d=5,
        )
        torch.testing.assert_close(x, result)

        # Passing every axis size (even redundant ones) is also accepted.
        sizes = dict(zip("abcdef", shape))
        temp = rearrange(x, "a b c d e f -> (f d) c (e b) a", **sizes)
        result = rearrange(temp, "(f d) c (e b) a -> a b c d e f", **sizes)
        torch.testing.assert_close(x, result)

        # With "a b c -> b c a", the element at (i, j, k) moves to (j, k, i).
        x2 = torch.arange(2 * 3 * 4).reshape([2, 3, 4])
        result = rearrange(x2, "a b c -> b c a")
        self.assertEqual(x2[1, 2, 3], result[2, 3, 1])
        self.assertEqual(x2[0, 1, 2], result[1, 2, 0])

    def test_rearrange_permutations(self) -> None:
        """Random axis permutations agree with two independent reference checks."""
        # Check 1: after permuting axes, the element at index `pick` must be
        # found at the index obtained by permuting `pick` the same way.
        for n_axes in range(1, 10):
            tensor = torch.arange(2**n_axes).reshape([2] * n_axes)
            permutation = np.random.permutation(n_axes)
            left_expression = " ".join("i" + str(axis) for axis in range(n_axes))
            right_expression = " ".join("i" + str(axis) for axis in permutation)
            expression = left_expression + " -> " + right_expression
            result = rearrange(tensor, expression)

            for pick in np.random.randint(0, 2, [10, n_axes]):
                self.assertEqual(tensor[tuple(pick)], result[tuple(pick[permutation])])

        # Check 2: every axis has size 2 and each element equals its own flat
        # index. With both sides written in reverse order, axis "i{k}" carries
        # bit k of each value. The expected output therefore moves bit k of
        # every value to bit permutation[k].
        for n_axes in range(1, 10):
            tensor = torch.arange(2**n_axes).reshape([2] * n_axes)
            permutation = np.random.permutation(n_axes)
            left_expression = " ".join("i" + str(axis) for axis in range(n_axes)[::-1])
            right_expression = " ".join("i" + str(axis) for axis in permutation[::-1])
            expression = left_expression + " -> " + right_expression
            result = rearrange(tensor, expression)
            self.assertEqual(result.shape, tensor.shape)
            expected_result = torch.zeros_like(tensor)
            for original_axis, result_axis in enumerate(permutation):
                expected_result |= ((tensor >> original_axis) & 1) << result_axis

            torch.testing.assert_close(result, expected_result)

    def test_concatenations_and_stacking(self) -> None:
        """A list of tensors passed to ``rearrange`` matches ``torch.stack``."""
        # Loop-invariant: the same shapes are tried for every list length.
        shapes: List[List[int]] = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
        for n_arrays in [1, 2, 5]:
            for shape in shapes:
                arrays1 = [
                    torch.arange(i, i + np.prod(shape, dtype=int)).reshape(shape)
                    for i in range(n_arrays)
                ]
                result0 = torch.stack(arrays1)
                result1 = rearrange(arrays1, "...->...")
                torch.testing.assert_close(result0, result1)

    def test_unsqueeze(self) -> None:
        """Literal ``1`` axes on the output side insert singleton dimensions."""
        x = torch.randn((2, 3, 4, 5))
        actual = rearrange(x, "b h w c -> b 1 h w 1 c")
        expected = x.unsqueeze(1).unsqueeze(-2)
        torch.testing.assert_close(actual, expected)

    def test_squeeze(self) -> None:
        """Literal ``1`` axes on the input side are removed from the output."""
        x = torch.randn((2, 1, 3, 4, 1, 5))
        actual = rearrange(x, "b 1 h w 1 c -> b h w c")
        expected = x.squeeze()
        torch.testing.assert_close(actual, expected)

    def test_0_dim_tensor(self) -> None:
        """Empty and ellipsis-only patterns leave a 0-d tensor unchanged."""
        x = expected = torch.tensor(1)
        actual = rearrange(x, "->")
        torch.testing.assert_close(actual, expected)

        actual = rearrange(x, "... -> ...")
        torch.testing.assert_close(actual, expected)

    def test_dimension_mismatch_no_ellipsis(self) -> None:
        """Too few or too many named axes for the tensor's rank raise ValueError."""
        x = torch.randn((1, 2, 3))
        with self.assertRaises(ValueError):
            rearrange(x, "a b -> b a")

        with self.assertRaises(ValueError):
            rearrange(x, "a b c d -> c d b a")

    def test_dimension_mismatch_with_ellipsis(self) -> None:
        """A pattern needing more named axes than a 0-d tensor has raises ValueError."""
        x = torch.tensor(1)
        with self.assertRaises(ValueError):
            rearrange(x, "a ... -> ... a")
193            rearrange(x, "a ... -> ... a")
194
195
if __name__ == "__main__":
    # Running the file directly hands control to PyTorch's shared test runner.
    run_tests()
198