"""Tests for how ``tools.jit.gen_unboxing.main`` forwards allowlist options."""

import tempfile
import unittest
from unittest.mock import NonCallableMock, patch

import tools.jit.gen_unboxing as gen_unboxing


# Class-level ``patch`` decorators are applied bottom-up, so the mocks reach
# each test method in the reverse of the order written here: ``gen_unboxing``
# first, ``get_custom_build_selector`` last.
@patch("tools.jit.gen_unboxing.get_custom_build_selector")
@patch("tools.jit.gen_unboxing.parse_native_yaml")
@patch("tools.jit.gen_unboxing.make_file_manager")
@patch("tools.jit.gen_unboxing.gen_unboxing")
class TestGenUnboxing(unittest.TestCase):
    """Check the arguments ``main`` passes to ``get_custom_build_selector``.

    The four collaborators named in the decorators are replaced by mocks for
    every test method; each test only inspects the call recorded on
    ``mock_get_custom_build_selector``.
    """

    def test_get_custom_build_selector_with_allowlist(
        self,
        mock_gen_unboxing: NonCallableMock,
        mock_make_file_manager: NonCallableMock,
        mock_parse_native_yaml: NonCallableMock,
        mock_get_custom_build_selector: NonCallableMock,
    ) -> None:
        """An inline allowlist is forwarded as a one-element list."""
        args = ["--op-registration-allowlist=op1", "--op-selection-yaml-path=path2"]
        gen_unboxing.main(args)
        mock_get_custom_build_selector.assert_called_once_with(["op1"], "path2")

    def test_get_custom_build_selector_with_allowlist_yaml(
        self,
        mock_gen_unboxing: NonCallableMock,
        mock_make_file_manager: NonCallableMock,
        mock_parse_native_yaml: NonCallableMock,
        mock_get_custom_build_selector: NonCallableMock,
    ) -> None:
        """An allowlist supplied via a YAML file is forwarded as its contents."""
        # The context manager removes the temp file even if the assertion fails.
        with tempfile.NamedTemporaryFile() as temp_file:
            temp_file.write(b"- aten::add.Tensor")
            # Push the buffered bytes to disk so the file can be read by name.
            temp_file.flush()
            args = [
                f"--TEST-ONLY-op-registration-allowlist-yaml-path={temp_file.name}",
                "--op-selection-yaml-path=path2",
            ]
            gen_unboxing.main(args)
            mock_get_custom_build_selector.assert_called_once_with(
                ["aten::add.Tensor"], "path2"
            )

    def test_get_custom_build_selector_with_both_allowlist_and_yaml(
        self,
        mock_gen_unboxing: NonCallableMock,
        mock_make_file_manager: NonCallableMock,
        mock_parse_native_yaml: NonCallableMock,
        mock_get_custom_build_selector: NonCallableMock,
    ) -> None:
        """When both sources are given, the inline allowlist is the one forwarded."""
        with tempfile.NamedTemporaryFile() as temp_file:
            temp_file.write(b"- aten::add.Tensor")
            temp_file.flush()
            args = [
                "--op-registration-allowlist=op1",
                # Must be an f-string: without the prefix the literal text
                # "{temp_file.name}" is passed instead of the real YAML path.
                f"--TEST-ONLY-op-registration-allowlist-yaml-path={temp_file.name}",
                "--op-selection-yaml-path=path2",
            ]
            gen_unboxing.main(args)
            mock_get_custom_build_selector.assert_called_once_with(["op1"], "path2")


if __name__ == "__main__":
    unittest.main()