1import importlib 2from typing import List, Optional 3 4from torch.testing._internal.common_utils import TestCase 5 6 7class AOMigrationTestCase(TestCase): 8 def _test_function_import( 9 self, 10 package_name: str, 11 function_list: List[str], 12 base: Optional[str] = None, 13 new_package_name: Optional[str] = None, 14 ): 15 r"""Tests individual function list import by comparing the functions 16 and their hashes.""" 17 if base is None: 18 base = "quantization" 19 old_base = "torch." + base 20 new_base = "torch.ao." + base 21 if new_package_name is None: 22 new_package_name = package_name 23 old_location = importlib.import_module(f"{old_base}.{package_name}") 24 new_location = importlib.import_module(f"{new_base}.{new_package_name}") 25 for fn_name in function_list: 26 old_function = getattr(old_location, fn_name) 27 new_function = getattr(new_location, fn_name) 28 assert old_function == new_function, f"Functions don't match: {fn_name}" 29 assert hash(old_function) == hash(new_function), ( 30 f"Hashes don't match: {old_function}({hash(old_function)}) vs. " 31 f"{new_function}({hash(new_function)})" 32 ) 33 34 def _test_dict_import( 35 self, package_name: str, dict_list: List[str], base: Optional[str] = None 36 ): 37 r"""Tests individual function list import by comparing the functions 38 and their hashes.""" 39 if base is None: 40 base = "quantization" 41 old_base = "torch." + base 42 new_base = "torch.ao." + base 43 old_location = importlib.import_module(f"{old_base}.{package_name}") 44 new_location = importlib.import_module(f"{new_base}.{package_name}") 45 for dict_name in dict_list: 46 old_dict = getattr(old_location, dict_name) 47 new_dict = getattr(new_location, dict_name) 48 assert old_dict == new_dict, f"Dicts don't match: {dict_name}" 49 for key in new_dict.keys(): 50 assert ( 51 old_dict[key] == new_dict[key] 52 ), f"Dicts don't match: {dict_name} for key {key}" 53