# xref: /aosp_15_r20/external/pytorch/test/quantization/ao_migration/common.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import importlib
2from typing import List, Optional
3
4from torch.testing._internal.common_utils import TestCase
5
6
class AOMigrationTestCase(TestCase):
    """Shared helpers for testing the ``torch.<base>`` -> ``torch.ao.<base>`` migration.

    Each helper imports the legacy module ``torch.<base>.<package_name>`` and the
    migrated module ``torch.ao.<base>.<new_package_name>``. It then checks that the
    requested attributes compare equal in both. ``base`` defaults to
    ``"quantization"`` and ``new_package_name`` defaults to ``package_name``.

    The checks use ``self.assertTrue`` rather than bare ``assert`` statements,
    because ``assert`` is removed when Python runs with ``-O``.
    """

    def _import_old_and_new(
        self,
        package_name: str,
        base: Optional[str] = None,
        new_package_name: Optional[str] = None,
    ):
        """Import the legacy and migrated modules that a check compares.

        Args:
            package_name: Module name under the legacy ``torch.<base>`` package.
            base: Package name used in both locations; ``None`` means
                ``"quantization"``.
            new_package_name: Module name under ``torch.ao.<base>``; ``None``
                means the same as ``package_name``.

        Returns:
            A ``(old_module, new_module)`` tuple.

        Raises:
            Whatever ``importlib.import_module`` raises if either module cannot
            be imported (e.g. ``ModuleNotFoundError``).
        """
        if base is None:
            base = "quantization"
        if new_package_name is None:
            new_package_name = package_name
        old_location = importlib.import_module(f"torch.{base}.{package_name}")
        new_location = importlib.import_module(f"torch.ao.{base}.{new_package_name}")
        return old_location, new_location

    def _test_function_import(
        self,
        package_name: str,
        function_list: List[str],
        base: Optional[str] = None,
        new_package_name: Optional[str] = None,
    ):
        """Check that functions resolve to the same object at both locations.

        For every name in ``function_list``, the attribute fetched from the
        legacy module must compare equal to the one fetched from the migrated
        module, and the two must have equal hashes.

        Args:
            package_name: Module name under ``torch.<base>``.
            function_list: Attribute names to look up in both modules.
            base: Package used in both locations (default ``"quantization"``).
            new_package_name: Module name under ``torch.ao.<base>``
                (default: ``package_name``).

        Raises:
            AssertionError: If an attribute pair, or its hashes, differ.
            AttributeError: If a name is missing from either module.
        """
        old_location, new_location = self._import_old_and_new(
            package_name, base, new_package_name
        )
        for fn_name in function_list:
            old_function = getattr(old_location, fn_name)
            new_function = getattr(new_location, fn_name)
            # assertTrue(a == b) keeps plain ``==`` semantics while surviving -O.
            self.assertTrue(
                old_function == new_function, f"Functions don't match: {fn_name}"
            )
            old_hash, new_hash = hash(old_function), hash(new_function)
            self.assertTrue(
                old_hash == new_hash,
                f"Hashes don't match: {old_function}({old_hash}) vs. "
                f"{new_function}({new_hash})",
            )

    def _test_dict_import(
        self,
        package_name: str,
        dict_list: List[str],
        base: Optional[str] = None,
        new_package_name: Optional[str] = None,
    ):
        """Check that dicts have identical contents at both locations.

        For every name in ``dict_list``, each entry of the migrated dict must
        exist in the legacy dict with an equal value. The two dicts must also
        compare equal as a whole.

        Args:
            package_name: Module name under ``torch.<base>``.
            dict_list: Attribute names of the dicts to compare.
            base: Package used in both locations (default ``"quantization"``).
            new_package_name: Module name under ``torch.ao.<base>``
                (default: ``package_name``). Added for parity with
                ``_test_function_import``; existing callers are unaffected.

        Raises:
            AssertionError: If an entry, or the dicts as a whole, differ.
            AttributeError: If a name is missing from either module.
        """
        old_location, new_location = self._import_old_and_new(
            package_name, base, new_package_name
        )
        for dict_name in dict_list:
            old_dict = getattr(old_location, dict_name)
            new_dict = getattr(new_location, dict_name)
            # Compare entry by entry first so that a failure names the offending
            # key. A key absent from old_dict becomes a test failure, not a
            # KeyError.
            for key in new_dict:
                self.assertTrue(
                    key in old_dict and old_dict[key] == new_dict[key],
                    f"Dicts don't match: {dict_name} for key {key}",
                )
            # The whole-dict comparison also catches keys present only in old_dict.
            self.assertTrue(old_dict == new_dict, f"Dicts don't match: {dict_name}")
53