xref: /aosp_15_r20/external/executorch/codegen/tools/test/test_gen_oplist.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import json
8import os
9import tempfile
10import unittest
11from typing import Dict, List
12from unittest.mock import NonCallableMock, patch
13
14import executorch.codegen.tools.gen_oplist as gen_oplist
15import yaml
16
17
18class TestGenOpList(unittest.TestCase):
19    def setUp(self):
20        self.temp_dir = tempfile.TemporaryDirectory()
21        self.ops_schema_yaml = os.path.join(self.temp_dir.name, "test.yaml")
22        with open(self.ops_schema_yaml, "w") as f:
23            f.write(
24                """
25- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
26  device_check: NoCheck   # TensorIterator
27  dispatch:
28    CPU: torch::executor::add_out_kernel
29
30- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
31  device_check: NoCheck   # TensorIterator
32  dispatch:
33    CPU: torch::executor::mul_out_kernel
34            """
35            )
36
37    @patch("executorch.codegen.tools.gen_oplist._get_operators")
38    @patch("executorch.codegen.tools.gen_oplist._dump_yaml")
39    def test_gen_op_list_with_wrong_path(
40        self,
41        mock_dump_yaml: NonCallableMock,
42        mock_get_operators: NonCallableMock,
43    ) -> None:
44        args = ["--output_path=wrong_path", "--model_file_path=path2"]
45        with self.assertRaises(RuntimeError):
46            gen_oplist.main(args)
47
48    @patch("executorch.codegen.tools.gen_oplist._get_kernel_metadata_for_model")
49    @patch("executorch.codegen.tools.gen_oplist._get_operators")
50    @patch("executorch.codegen.tools.gen_oplist._dump_yaml")
51    def test_gen_op_list_with_valid_model_path(
52        self,
53        mock_get_kernel_metadata_for_model: NonCallableMock,
54        mock_dump_yaml: NonCallableMock,
55        mock_get_operators: NonCallableMock,
56    ) -> None:
57        temp_file = tempfile.NamedTemporaryFile()
58        args = [
59            f"--output_path={os.path.join(self.temp_dir.name, 'output.yaml')}",
60            f"--model_file_path={temp_file.name}",
61        ]
62        gen_oplist.main(args)
63        mock_get_operators.assert_called_once_with(temp_file.name)
64        temp_file.close()
65
66    @patch("executorch.codegen.tools.gen_oplist._dump_yaml")
67    def test_gen_op_list_with_valid_root_ops(
68        self,
69        mock_dump_yaml: NonCallableMock,
70    ) -> None:
71        output_path = os.path.join(self.temp_dir.name, "output.yaml")
72        args = [
73            f"--output_path={output_path}",
74            "--root_ops=aten::add,aten::mul",
75        ]
76        gen_oplist.main(args)
77        mock_dump_yaml.assert_called_once_with(
78            ["aten::add", "aten::mul"],
79            output_path,
80            None,
81            {"aten::add": ["default"], "aten::mul": ["default"]},
82            False,
83        )
84
85    @patch("executorch.codegen.tools.gen_oplist._dump_yaml")
86    def test_gen_op_list_with_root_ops_and_dtypes(
87        self,
88        mock_dump_yaml: NonCallableMock,
89    ) -> None:
90        output_path = os.path.join(self.temp_dir.name, "output.yaml")
91        ops_dict = {
92            "aten::add": ["v1/3;0,1|3;0,1|3;0,1|3;0,1", "v1/6;0,1|6;0,1|6;0,1|6;0,1"],
93            "aten::mul": [],
94        }
95        args = [
96            f"--output_path={output_path}",
97            f"--ops_dict={json.dumps(ops_dict)}",
98        ]
99        gen_oplist.main(args)
100        mock_dump_yaml.assert_called_once_with(
101            ["aten::add", "aten::mul"],
102            output_path,
103            None,
104            {
105                "aten::add": [
106                    "v1/3;0,1|3;0,1|3;0,1|3;0,1",
107                    "v1/6;0,1|6;0,1|6;0,1|6;0,1",
108                ],
109                "aten::mul": ["default"],
110            },
111            False,
112        )
113
114    @patch("executorch.codegen.tools.gen_oplist._get_operators")
115    @patch("executorch.codegen.tools.gen_oplist._dump_yaml")
116    def test_gen_op_list_with_both_op_list_and_ops_schema_yaml_merges(
117        self,
118        mock_dump_yaml: NonCallableMock,
119        mock_get_operators: NonCallableMock,
120    ) -> None:
121        output_path = os.path.join(self.temp_dir.name, "output.yaml")
122        args = [
123            f"--output_path={output_path}",
124            "--root_ops=aten::relu.out",
125            f"--ops_schema_yaml_path={self.ops_schema_yaml}",
126        ]
127        gen_oplist.main(args)
128        mock_dump_yaml.assert_called_once_with(
129            ["aten::add.out", "aten::mul.out", "aten::relu.out"],
130            output_path,
131            "test.yaml",
132            {
133                "aten::relu.out": ["default"],
134                "aten::add.out": ["default"],
135                "aten::mul.out": ["default"],
136            },
137            False,
138        )
139
140    @patch("executorch.codegen.tools.gen_oplist._dump_yaml")
141    def test_gen_op_list_with_include_all_operators(
142        self,
143        mock_dump_yaml: NonCallableMock,
144    ) -> None:
145        output_path = os.path.join(self.temp_dir.name, "output.yaml")
146        args = [
147            f"--output_path={output_path}",
148            "--root_ops=aten::add,aten::mul",
149            "--include_all_operators",
150        ]
151        gen_oplist.main(args)
152        mock_dump_yaml.assert_called_once_with(
153            ["aten::add", "aten::mul"],
154            output_path,
155            None,
156            {"aten::add": ["default"], "aten::mul": ["default"]},
157            True,
158        )
159
160    def test_get_custom_build_selector_with_both_allowlist_and_yaml(
161        self,
162    ) -> None:
163        op_list = ["aten::add", "aten::mul"]
164        filename = os.path.join(self.temp_dir.name, "selected_operators.yaml")
165        gen_oplist._dump_yaml(op_list, filename, "model.pte")
166        self.assertTrue(os.path.isfile(filename))
167        with open(filename) as f:
168            es = yaml.safe_load(f)
169        ops = es["operators"]
170        self.assertEqual(len(ops), 2)
171        self.assertSetEqual(set(ops.keys()), set(op_list))
172
173    def test_gen_oplist_generates_from_root_ops(
174        self,
175    ) -> None:
176        filename = os.path.join(self.temp_dir.name, "selected_operators.yaml")
177        op_list = ["aten::add.out", "aten::mul.out", "aten::relu.out"]
178        comma = ","
179        args = [
180            f"--output_path={filename}",
181            f"--root_ops={comma.join(op_list)}",
182        ]
183        gen_oplist.main(args)
184        self.assertTrue(os.path.isfile(filename))
185        with open(filename) as f:
186            es = yaml.safe_load(f)
187        ops = es["operators"]
188        self.assertEqual(len(ops), 3)
189        self.assertSetEqual(set(ops.keys()), set(op_list))
190
191    def test_dump_operator_from_ops_schema_yaml(self) -> None:
192        ops = gen_oplist._get_et_kernel_metadata_from_ops_yaml(self.ops_schema_yaml)
193        self.assertListEqual(sorted(ops.keys()), ["aten::add.out", "aten::mul.out"])
194
195    def test_dump_operator_from_ops_schema_yaml_with_op_syntax(self) -> None:
196        ops_yaml = os.path.join(self.temp_dir.name, "ops.yaml")
197        with open(ops_yaml, "w") as f:
198            f.write(
199                """
200- op: add.out
201  device_check: NoCheck   # TensorIterator
202  dispatch:
203    CPU: torch::executor::add_out_kernel
204
205- op: mul.out
206  device_check: NoCheck   # TensorIterator
207  dispatch:
208    CPU: torch::executor::mul_out_kernel
209            """
210            )
211        ops = gen_oplist._get_et_kernel_metadata_from_ops_yaml(ops_yaml)
212        self.assertListEqual(sorted(ops.keys()), ["aten::add.out", "aten::mul.out"])
213
214    def test_dump_operator_from_ops_schema_yaml_with_mix_syntax(self) -> None:
215        mix_yaml = os.path.join(self.temp_dir.name, "mix.yaml")
216        with open(mix_yaml, "w") as f:
217            f.write(
218                """
219- op: add.out
220  device_check: NoCheck   # TensorIterator
221  dispatch:
222    CPU: torch::executor::add_out_kernel
223
224- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
225  device_check: NoCheck   # TensorIterator
226  dispatch:
227    CPU: torch::executor::mul_out_kernel
228            """
229            )
230        ops = gen_oplist._get_et_kernel_metadata_from_ops_yaml(mix_yaml)
231        self.assertListEqual(sorted(ops.keys()), ["aten::add.out", "aten::mul.out"])
232
233    def test_get_kernel_metadata_from_ops_yaml(self) -> None:
234        metadata: Dict[str, List[str]] = (
235            gen_oplist._get_et_kernel_metadata_from_ops_yaml(self.ops_schema_yaml)
236        )
237
238        self.assertEqual(len(metadata), 2)
239
240        self.assertIn("aten::add.out", metadata)
241        # We only have one dtype/dim-order combo for add (float/0,1)
242        self.assertEqual(len(metadata["aten::add.out"]), 1)
243        self.assertEqual(
244            metadata["aten::add.out"][0],
245            "default",
246        )
247
248        self.assertIn("aten::mul.out", metadata)
249        self.assertEqual(len(metadata["aten::mul.out"]), 1)
250        self.assertEqual(
251            metadata["aten::mul.out"][0],
252            "default",
253        )
254
255    def tearDown(self):
256        self.temp_dir.cleanup()
257
258
if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main(module="__main__")
261