1# Owner(s): ["oncall: export"] 2from collections import OrderedDict 3 4import torch 5from torch._dynamo.test_case import TestCase 6from torch.export._tree_utils import is_equivalent, reorder_kwargs 7from torch.testing._internal.common_utils import run_tests 8from torch.utils._pytree import tree_structure 9 10 11class TestTreeUtils(TestCase): 12 def test_reorder_kwargs(self): 13 original_kwargs = {"a": torch.tensor(0), "b": torch.tensor(1)} 14 user_kwargs = {"b": torch.tensor(2), "a": torch.tensor(3)} 15 orig_spec = tree_structure(((), original_kwargs)) 16 17 reordered_kwargs = reorder_kwargs(user_kwargs, orig_spec) 18 19 # Key ordering should be the same 20 self.assertEqual(reordered_kwargs.popitem()[0], original_kwargs.popitem()[0]), 21 self.assertEqual(reordered_kwargs.popitem()[0], original_kwargs.popitem()[0]), 22 23 def test_equivalence_check(self): 24 tree1 = {"a": torch.tensor(0), "b": torch.tensor(1), "c": None} 25 tree2 = OrderedDict(a=torch.tensor(0), b=torch.tensor(1), c=None) 26 spec1 = tree_structure(tree1) 27 spec2 = tree_structure(tree2) 28 29 def dict_ordered_dict_eq(type1, context1, type2, context2): 30 if type1 is None or type2 is None: 31 return type1 is type2 and context1 == context2 32 33 if issubclass(type1, (dict, OrderedDict)) and issubclass( 34 type2, (dict, OrderedDict) 35 ): 36 return context1 == context2 37 38 return type1 is type2 and context1 == context2 39 40 self.assertTrue(is_equivalent(spec1, spec2, dict_ordered_dict_eq)) 41 42 # Wrong ordering should still fail 43 tree3 = OrderedDict(b=torch.tensor(1), a=torch.tensor(0)) 44 spec3 = tree_structure(tree3) 45 self.assertFalse(is_equivalent(spec1, spec3, dict_ordered_dict_eq)) 46 47 48if __name__ == "__main__": 49 run_tests() 50