xref: /aosp_15_r20/external/pytorch/test/export/test_tree_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2from collections import OrderedDict
3
4import torch
5from torch._dynamo.test_case import TestCase
6from torch.export._tree_utils import is_equivalent, reorder_kwargs
7from torch.testing._internal.common_utils import run_tests
8from torch.utils._pytree import tree_structure
9
10
11class TestTreeUtils(TestCase):
12    def test_reorder_kwargs(self):
13        original_kwargs = {"a": torch.tensor(0), "b": torch.tensor(1)}
14        user_kwargs = {"b": torch.tensor(2), "a": torch.tensor(3)}
15        orig_spec = tree_structure(((), original_kwargs))
16
17        reordered_kwargs = reorder_kwargs(user_kwargs, orig_spec)
18
19        # Key ordering should be the same
20        self.assertEqual(reordered_kwargs.popitem()[0], original_kwargs.popitem()[0]),
21        self.assertEqual(reordered_kwargs.popitem()[0], original_kwargs.popitem()[0]),
22
23    def test_equivalence_check(self):
24        tree1 = {"a": torch.tensor(0), "b": torch.tensor(1), "c": None}
25        tree2 = OrderedDict(a=torch.tensor(0), b=torch.tensor(1), c=None)
26        spec1 = tree_structure(tree1)
27        spec2 = tree_structure(tree2)
28
29        def dict_ordered_dict_eq(type1, context1, type2, context2):
30            if type1 is None or type2 is None:
31                return type1 is type2 and context1 == context2
32
33            if issubclass(type1, (dict, OrderedDict)) and issubclass(
34                type2, (dict, OrderedDict)
35            ):
36                return context1 == context2
37
38            return type1 is type2 and context1 == context2
39
40        self.assertTrue(is_equivalent(spec1, spec2, dict_ordered_dict_eq))
41
42        # Wrong ordering should still fail
43        tree3 = OrderedDict(b=torch.tensor(1), a=torch.tensor(0))
44        spec3 = tree_structure(tree3)
45        self.assertFalse(is_equivalent(spec1, spec3, dict_ordered_dict_eq))
46
47
# When executed directly as a script, hand off to ``run_tests`` (imported from
# ``torch.testing._internal.common_utils``) to run the tests in this module.
if __name__ == "__main__":
    run_tests()
50