xref: /aosp_15_r20/external/executorch/extension/pytree/test/test.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8from collections import namedtuple
9from typing import Any, Dict
10
11import torch
12
13# @manual=//executorch/extension/pytree:pybindings
14from executorch.extension.pytree import (
15    broadcast_to_and_flatten,
16    register_custom,
17    tree_flatten,
18    tree_map,
19    tree_unflatten,
20    TreeSpec,
21)
22
23
# pyre-fixme[11]: Annotation `TreeSpec` is not defined as a type.
def _spec(o: Any) -> TreeSpec:
    """Flatten ``o`` and return only its TreeSpec, discarding the leaves."""
    # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
    _leaves, treespec = tree_flatten(o)
    return treespec
29
30
31# Constructs string representation of pytree spec of type specified by type_char (can be 'T' for tuple, 'L' for List) argument, that contains n children, each with single leaf.
32# e.g. ('T', 3) -> 'T3#1#1#1($,$,$)'
33def _spec_str(type_char, n: int) -> str:
34    spec = type_char + str(n)
35    for _ in range(n):
36        spec += "#1"
37    spec += "("
38    for i in range(n):
39        if i > 0:
40            spec += ","
41        spec += "$"
42    spec += ")"
43    return spec
44
45
46# Constructs string representation of pytree spec of Dict, keys can be str or int, every value is leaf.
47# e.g.: {'a': 1, 2: 2} -> D2#1#1('a':$,2:$)
48def _spec_str_dict(d: Dict[Any, Any]) -> str:
49    n = len(d)
50    spec = "D" + str(n)
51    for _ in range(n):
52        spec += "#1"
53    spec += "("
54    i = 0
55    for key in d.keys():
56        if i > 0:
57            spec += ","
58        if isinstance(key, str):
59            spec += "'" + key + "'"
60        else:
61            spec += str(key)
62        spec += ":$"
63        i += 1
64    spec += ")"
65    return spec
66
67
class TestPytree(unittest.TestCase):
    """Tests for the executorch pytree bindings.

    Spec strings, as shown by the expected values below, have the form
    ``<tag><num children>#<leaves in child 1>#...(<child 1>,...)`` where ``$``
    is a leaf and dict entries are prefixed with ``'key':`` or ``key:``.
    """

    def test(self):
        """Round-trip a mixed nested dict through flatten/to_str/unflatten/from_str."""
        SPEC = "D4#2#1#2#2('a':L2#1#1($,$),1:$,2:T2#1#1($,$),'str':D2#1#1('str':$,'str2':$))"
        d = {}
        d["a"] = [777, 1]
        d[1] = 4
        d[2] = ("ta", 2)
        d["str"] = {"str": 23, "str2": "47str"}
        (leaves, pytree) = tree_flatten(d)
        self.assertEqual(leaves, [777, 1, 4, "ta", 2, 23, "47str"])
        pytree_str = pytree.to_str()
        self.assertEqual(pytree_str, SPEC)

        # Replacement leaves 13, "14", 15, "16", ...: ints at even positions and
        # strs at odd positions, unflattened back into the original structure.
        leaves_test = [
            i + 13 if i % 2 == 0 else str(i + 13) for i in range(len(leaves))
        ]

        tree_test = pytree.tree_unflatten(leaves_test)
        self.assertEqual(
            tree_test,
            {"a": [13, "14"], 1: 15, 2: ("16", 17), "str": {"str": "18", "str2": 19}},
        )

        # Parsing a spec string and re-serializing it must be lossless.
        pytree_from = TreeSpec.from_str(SPEC)
        spec_str_to = pytree_from.to_str()
        self.assertEqual(SPEC, spec_str_to)

    def test_extract_nested_list(self):
        nested_struct = (1, 2, [3, 4])
        (_, pytree) = tree_flatten(nested_struct)
        self.assertEqual(pytree.to_str(), "T3#1#1#2($,$,L2#1#1($,$))")

    def test_extract_nested_dict(self):
        nested_struct = (1, 2, {3: 4, "str": 6})
        (_, pytree) = tree_flatten(nested_struct)
        self.assertEqual(pytree.to_str(), "T3#1#1#2($,$,D2#1#1(3:$,'str':$))")

    def test_extracted_scalar(self):
        struct = 4
        (_, pytree) = tree_flatten(struct)
        self.assertEqual(pytree.to_str(), "$")

    def test_map(self):
        struct = (1, 2, [3, 4])
        struct_map = tree_map(lambda x: 2 * x, struct)
        self.assertEqual(struct_map, (2, 4, [6, 8]))

    def test_treespec_equality(self):
        """Exercise TreeSpec ``==``/``!=`` directly (hence assertTrue on the operators)."""
        self.assertTrue(TreeSpec.from_str("$") == TreeSpec.from_str("$"))
        self.assertTrue(_spec([1]) == TreeSpec.from_str("L1#1($)"))
        # NB: ``(1)`` is just the int 1; a one-element tuple needs the trailing
        # comma, otherwise the tuple-vs-list comparison is never tested.
        self.assertTrue(_spec((1,)) != _spec([1]))
        self.assertTrue(_spec((1,)) == _spec((2,)))
        # Two bare leaves compare equal regardless of their values.
        self.assertTrue(_spec(1) == _spec(2))

    def test_flatten_unflatten_leaf(self):
        def run_test_with_leaf(leaf):
            values, treespec = tree_flatten(leaf)
            self.assertEqual(values, [leaf])
            self.assertEqual(treespec, TreeSpec.from_str("$"))

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, leaf)

        # ``bool`` here is the type object itself, used as an opaque leaf.
        for leaf in (1, 1.0, None, bool):
            with self.subTest(leaf=leaf):
                run_test_with_leaf(leaf)

    def test_flatten_unflatten_list(self):
        def run_test(lst):
            expected_spec = TreeSpec.from_str(_spec_str("L", len(lst)))
            values, treespec = tree_flatten(lst)
            self.assertIsInstance(values, list)
            self.assertEqual(values, lst)
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, lst)
            self.assertIsInstance(unflattened, list)

        run_test([])
        run_test([1.0, 2])
        run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])

    def test_flatten_unflatten_tuple(self):
        def run_test(tup):
            expected_spec = TreeSpec.from_str(_spec_str("T", len(tup)))
            values, treespec = tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertIsInstance(unflattened, tuple)

        run_test(())
        run_test((1.0,))
        run_test((1.0, 2))
        run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))

    def test_flatten_unflatten_namedtuple(self):
        Point = namedtuple("Point", ["x", "y"])

        def run_test(tup):
            expected_spec = TreeSpec.from_str(_spec_str("N", len(tup)))

            values, treespec = tree_flatten(tup)
            self.assertIsInstance(values, list)

            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)

        run_test(Point(1.0, 2))
        run_test(Point(torch.tensor(1.0), 2))

    def test_flatten_unflatten_torch_namedtuple_return_type(self):
        x = torch.randn(3, 3)
        expected = torch.max(x, dim=0)

        values, spec = tree_flatten(expected)
        result = tree_unflatten(values, spec)

        self.assertEqual(type(result), type(expected))
        self.assertEqual(result, expected)

    def test_flatten_unflatten_dict(self):
        def run_test(d):
            values, treespec = tree_flatten(d)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(d.values()))
            self.assertEqual(treespec, TreeSpec.from_str(_spec_str_dict(d)))

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, d)
            self.assertIsInstance(unflattened, dict)

        run_test({})
        run_test({"a": 1})
        run_test({"abcdefg": torch.randn(2, 3)})
        run_test({1: torch.randn(2, 3)})
        run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})

    def test_flatten_unflatten_nested(self):
        def run_test(pytree):
            values, treespec = tree_flatten(pytree)
            self.assertIsInstance(values, list)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, pytree)

        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 0, "b": [{"c": 1}]},
            {"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
        ]
        for case in cases:
            with self.subTest(case=case):
                run_test(case)

    def test_treemap(self):
        """tree_map must agree with mapping the flat leaves and must preserve structure."""

        def f(x):
            return x * 3

        def invf(x):
            return x // 3

        def run_test(pytree):
            leaves = tree_flatten(pytree)[0]
            mapped_leaves = tree_flatten(tree_map(f, pytree))[0]
            self.assertEqual([f(x) for x in leaves], mapped_leaves)

            # f followed by its (integer) inverse must reproduce the input tree.
            self.assertEqual(tree_map(invf, tree_map(f, pytree)), pytree)

        # Previously this loop was indented inside run_test, so no case ever ran.
        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 1, "b": [{"c": 2}]},
            {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
        ]
        for case in cases:
            with self.subTest(case=case):
                run_test(case)

    def test_treespec_repr(self):
        pytree = (0, [0, 0, 0])
        _, spec = tree_flatten(pytree)
        self.assertEqual(repr(spec), "T2#1#3($,L3#1#1#1($,$,$))")

    def test_custom_tree_node(self):
        class Point:
            def __init__(self, x, y, name):
                self.x = x
                self.y = y
                self.name = name

            def __repr__(self):
                return "Point(x:{}, y:{}, name: {})".format(self.x, self.y, self.name)

        def custom_flatten(p):
            # x and y are the children; name travels as non-leaf context.
            children = [p.x, p.y]
            extra_data = p.name
            return (children, extra_data)

        def custom_unflatten(children, extra_data):
            return Point(*children, extra_data)

        register_custom(Point, custom_flatten, custom_unflatten)

        point = Point((1.0, 1.0, 1), 2.0, "point_name")
        children, spec = tree_flatten(point)
        point2 = tree_unflatten(children, spec)
        self.assertEqual(str(point), str(point2))

    def test_broadcast_to_and_flatten(self):
        # (pytree, to_pytree, expected): broadcast pytree onto to_pytree's spec;
        # expected is None for the structurally mismatched cases.
        cases = [
            (1, (), []),
            # Same (flat) structures
            ((1,), (0,), [1]),
            ([1], [0], [1]),
            ((1, 2, 3), (0, 0, 0), [1, 2, 3]),
            ({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]),
            # Mismatched (flat) structures
            ([1], (0,), None),
            ((1,), [0], None),
            ((1, 2, 3), (0, 0), None),
            ({"a": 1, "b": 2}, {"a": 0}, None),
            ({"a": 1, "b": 2}, {"a": 0, "c": 0}, None),
            ({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None),
            # Same (nested) structures
            ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
            ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),
            # Mismatched (nested) structures
            ((1, [2, 3]), (0, (0, 0)), None),
            ((1, [2, 3]), (0, [0, 0, 0]), None),
            # Broadcasting single value
            (1, (0, 0, 0), [1, 1, 1]),
            (1, [0, 0, 0], [1, 1, 1]),
            (1, {"a": 0, "b": 0}, [1, 1]),
            (1, (0, [0, [0]], 0), [1, 1, 1, 1]),
            (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),
            # Broadcast multiple things
            ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
            ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
            (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
        ]
        for pytree, to_pytree, expected in cases:
            to_spec = _spec(to_pytree)
            with self.subTest(pytree=pytree, to_pytree=to_pytree):
                result = broadcast_to_and_flatten(pytree, to_spec)
                self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))
332