# xref: /aosp_15_r20/external/pytorch/tools/jit/test/test_gen_unboxing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import tempfile
2import unittest
3from unittest.mock import NonCallableMock, patch
4
5import tools.jit.gen_unboxing as gen_unboxing
6
7
@patch("tools.jit.gen_unboxing.get_custom_build_selector")
@patch("tools.jit.gen_unboxing.parse_native_yaml")
@patch("tools.jit.gen_unboxing.make_file_manager")
@patch("tools.jit.gen_unboxing.gen_unboxing")
class TestGenUnboxing(unittest.TestCase):
    """Check the arguments ``gen_unboxing.main`` hands to the build selector.

    Four names inside ``tools.jit.gen_unboxing`` are patched for every test
    method. ``unittest.mock.patch`` injects the mocks bottom-up, so the
    decorator closest to the class (``gen_unboxing``) supplies the first mock
    argument and ``get_custom_build_selector`` supplies the last one.
    """

    def test_get_custom_build_selector_with_allowlist(
        self,
        mock_gen_unboxing: NonCallableMock,
        mock_make_file_manager: NonCallableMock,
        mock_parse_native_yaml: NonCallableMock,
        mock_get_custom_build_selector: NonCallableMock,
    ) -> None:
        """An allowlist given on the command line reaches the selector."""
        args = ["--op-registration-allowlist=op1", "--op-selection-yaml-path=path2"]
        gen_unboxing.main(args)
        mock_get_custom_build_selector.assert_called_once_with(["op1"], "path2")

    def test_get_custom_build_selector_with_allowlist_yaml(
        self,
        mock_gen_unboxing: NonCallableMock,
        mock_make_file_manager: NonCallableMock,
        mock_parse_native_yaml: NonCallableMock,
        mock_get_custom_build_selector: NonCallableMock,
    ) -> None:
        """An allowlist supplied through a YAML file reaches the selector."""
        # The context manager closes (and removes) the temporary file even
        # when ``main`` raises or the assertion below fails.
        with tempfile.NamedTemporaryFile() as temp_file:
            temp_file.write(b"- aten::add.Tensor")
            # Push the buffered bytes to disk: the file is passed to ``main``
            # by name rather than through this handle.
            temp_file.flush()
            args = [
                f"--TEST-ONLY-op-registration-allowlist-yaml-path={temp_file.name}",
                "--op-selection-yaml-path=path2",
            ]
            gen_unboxing.main(args)
            mock_get_custom_build_selector.assert_called_once_with(
                ["aten::add.Tensor"], "path2"
            )

    def test_get_custom_build_selector_with_both_allowlist_and_yaml(
        self,
        mock_gen_unboxing: NonCallableMock,
        mock_make_file_manager: NonCallableMock,
        mock_parse_native_yaml: NonCallableMock,
        mock_get_custom_build_selector: NonCallableMock,
    ) -> None:
        """With both sources given, the explicit allowlist is the one used."""
        with tempfile.NamedTemporaryFile() as temp_file:
            temp_file.write(b"- aten::add.Tensor")
            temp_file.flush()
            args = [
                "--op-registration-allowlist=op1",
                # Must be an f-string: without the prefix the literal text
                # "{temp_file.name}" was passed instead of the real path.
                f"--TEST-ONLY-op-registration-allowlist-yaml-path={temp_file.name}",
                "--op-selection-yaml-path=path2",
            ]
            gen_unboxing.main(args)
            mock_get_custom_build_selector.assert_called_once_with(["op1"], "path2")
62
63
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
66