# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import tempfile
import unittest
from typing import Dict, List
from unittest.mock import NonCallableMock, patch

import executorch.codegen.tools.gen_oplist as gen_oplist
import yaml

# Kernel schema YAML declaring add.out / mul.out with the full ``func:`` syntax.
_FUNC_SYNTAX_OPS_YAML = """
- func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck  # TensorIterator
  dispatch:
    CPU: torch::executor::add_out_kernel

- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck  # TensorIterator
  dispatch:
    CPU: torch::executor::mul_out_kernel
"""

# The same two kernels declared with the short ``op:`` (name-only) syntax.
_OP_SYNTAX_OPS_YAML = """
- op: add.out
  device_check: NoCheck  # TensorIterator
  dispatch:
    CPU: torch::executor::add_out_kernel

- op: mul.out
  device_check: NoCheck  # TensorIterator
  dispatch:
    CPU: torch::executor::mul_out_kernel
"""

# One ``op:`` entry and one ``func:`` entry in the same file.
_MIXED_SYNTAX_OPS_YAML = """
- op: add.out
  device_check: NoCheck  # TensorIterator
  dispatch:
    CPU: torch::executor::add_out_kernel

- func: mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  device_check: NoCheck  # TensorIterator
  dispatch:
    CPU: torch::executor::mul_out_kernel
"""


class TestGenOpList(unittest.TestCase):
    """Tests for ``gen_oplist.main`` and its YAML helper functions."""

    def setUp(self) -> None:
        self.temp_dir = tempfile.TemporaryDirectory()
        # Register cleanup right away: if anything below raises, tearDown is
        # skipped by unittest, but registered cleanups still run.
        self.addCleanup(self.temp_dir.cleanup)
        self.ops_schema_yaml = self._write_temp_file(
            "test.yaml", _FUNC_SYNTAX_OPS_YAML
        )

    def _write_temp_file(self, name: str, contents: str) -> str:
        """Write ``contents`` to ``name`` inside the per-test temp dir.

        Returns:
            The absolute path of the written file.
        """
        path = os.path.join(self.temp_dir.name, name)
        with open(path, "w", encoding="utf-8") as f:
            f.write(contents)
        return path

    @patch("executorch.codegen.tools.gen_oplist._get_operators")
    @patch("executorch.codegen.tools.gen_oplist._dump_yaml")
    def test_gen_op_list_with_wrong_path(
        self,
        mock_dump_yaml: NonCallableMock,
        mock_get_operators: NonCallableMock,
    ) -> None:
        """``main`` raises RuntimeError for these invalid path arguments."""
        # NOTE(review): relies on "path2" not existing in the working dir.
        args = ["--output_path=wrong_path", "--model_file_path=path2"]
        with self.assertRaises(RuntimeError):
            gen_oplist.main(args)

    @patch("executorch.codegen.tools.gen_oplist._get_kernel_metadata_for_model")
    @patch("executorch.codegen.tools.gen_oplist._get_operators")
    @patch("executorch.codegen.tools.gen_oplist._dump_yaml")
    def test_gen_op_list_with_valid_model_path(
        self,
        # Stacked @patch decorators inject mocks bottom-up: the decorator
        # closest to ``def`` (_dump_yaml) supplies the first mock argument.
        mock_dump_yaml: NonCallableMock,
        mock_get_operators: NonCallableMock,
        mock_get_kernel_metadata_for_model: NonCallableMock,
    ) -> None:
        """An existing --model_file_path is forwarded to the model readers."""
        # Context manager guarantees the temp file is closed (and deleted)
        # even when an assertion fails.
        with tempfile.NamedTemporaryFile() as temp_file:
            args = [
                f"--output_path={os.path.join(self.temp_dir.name, 'output.yaml')}",
                f"--model_file_path={temp_file.name}",
            ]
            gen_oplist.main(args)
            mock_get_operators.assert_called_once_with(temp_file.name)
            mock_get_kernel_metadata_for_model.assert_called_once_with(
                temp_file.name
            )

    @patch("executorch.codegen.tools.gen_oplist._dump_yaml")
    def test_gen_op_list_with_valid_root_ops(
        self,
        mock_dump_yaml: NonCallableMock,
    ) -> None:
        """--root_ops entries are dumped, each with "default" metadata."""
        output_path = os.path.join(self.temp_dir.name, "output.yaml")
        args = [
            f"--output_path={output_path}",
            "--root_ops=aten::add,aten::mul",
        ]
        gen_oplist.main(args)
        mock_dump_yaml.assert_called_once_with(
            ["aten::add", "aten::mul"],
            output_path,
            None,
            {"aten::add": ["default"], "aten::mul": ["default"]},
            False,
        )

    @patch("executorch.codegen.tools.gen_oplist._dump_yaml")
    def test_gen_op_list_with_root_ops_and_dtypes(
        self,
        mock_dump_yaml: NonCallableMock,
    ) -> None:
        """--ops_dict metadata is kept; an empty list becomes ["default"]."""
        output_path = os.path.join(self.temp_dir.name, "output.yaml")
        ops_dict = {
            "aten::add": ["v1/3;0,1|3;0,1|3;0,1|3;0,1", "v1/6;0,1|6;0,1|6;0,1|6;0,1"],
            "aten::mul": [],
        }
        args = [
            f"--output_path={output_path}",
            f"--ops_dict={json.dumps(ops_dict)}",
        ]
        gen_oplist.main(args)
        mock_dump_yaml.assert_called_once_with(
            ["aten::add", "aten::mul"],
            output_path,
            None,
            {
                "aten::add": [
                    "v1/3;0,1|3;0,1|3;0,1|3;0,1",
                    "v1/6;0,1|6;0,1|6;0,1|6;0,1",
                ],
                "aten::mul": ["default"],
            },
            False,
        )

    @patch("executorch.codegen.tools.gen_oplist._get_operators")
    @patch("executorch.codegen.tools.gen_oplist._dump_yaml")
    def test_gen_op_list_with_both_op_list_and_ops_schema_yaml_merges(
        self,
        mock_dump_yaml: NonCallableMock,
        mock_get_operators: NonCallableMock,
    ) -> None:
        """Ops from --root_ops and --ops_schema_yaml_path are merged."""
        output_path = os.path.join(self.temp_dir.name, "output.yaml")
        args = [
            f"--output_path={output_path}",
            "--root_ops=aten::relu.out",
            f"--ops_schema_yaml_path={self.ops_schema_yaml}",
        ]
        gen_oplist.main(args)
        mock_dump_yaml.assert_called_once_with(
            ["aten::add.out", "aten::mul.out", "aten::relu.out"],
            output_path,
            "test.yaml",
            {
                "aten::relu.out": ["default"],
                "aten::add.out": ["default"],
                "aten::mul.out": ["default"],
            },
            False,
        )

    @patch("executorch.codegen.tools.gen_oplist._dump_yaml")
    def test_gen_op_list_with_include_all_operators(
        self,
        mock_dump_yaml: NonCallableMock,
    ) -> None:
        """--include_all_operators is forwarded as the last _dump_yaml arg."""
        output_path = os.path.join(self.temp_dir.name, "output.yaml")
        args = [
            f"--output_path={output_path}",
            "--root_ops=aten::add,aten::mul",
            "--include_all_operators",
        ]
        gen_oplist.main(args)
        mock_dump_yaml.assert_called_once_with(
            ["aten::add", "aten::mul"],
            output_path,
            None,
            {"aten::add": ["default"], "aten::mul": ["default"]},
            True,
        )

    def test_get_custom_build_selector_with_both_allowlist_and_yaml(
        self,
    ) -> None:
        """_dump_yaml writes a file whose "operators" keys match the input."""
        op_list = ["aten::add", "aten::mul"]
        filename = os.path.join(self.temp_dir.name, "selected_operators.yaml")
        gen_oplist._dump_yaml(op_list, filename, "model.pte")
        self.assertTrue(os.path.isfile(filename))
        with open(filename, encoding="utf-8") as f:
            es = yaml.safe_load(f)
        ops = es["operators"]
        self.assertEqual(len(ops), 2)
        self.assertSetEqual(set(ops.keys()), set(op_list))

    def test_gen_oplist_generates_from_root_ops(
        self,
    ) -> None:
        """End-to-end: --root_ops produces a YAML listing exactly those ops."""
        filename = os.path.join(self.temp_dir.name, "selected_operators.yaml")
        op_list = ["aten::add.out", "aten::mul.out", "aten::relu.out"]
        args = [
            f"--output_path={filename}",
            f"--root_ops={','.join(op_list)}",
        ]
        gen_oplist.main(args)
        self.assertTrue(os.path.isfile(filename))
        with open(filename, encoding="utf-8") as f:
            es = yaml.safe_load(f)
        ops = es["operators"]
        self.assertEqual(len(ops), 3)
        self.assertSetEqual(set(ops.keys()), set(op_list))

    def test_dump_operator_from_ops_schema_yaml(self) -> None:
        """``func:`` entries are read as ``aten::<name>`` keys."""
        ops = gen_oplist._get_et_kernel_metadata_from_ops_yaml(self.ops_schema_yaml)
        self.assertListEqual(sorted(ops.keys()), ["aten::add.out", "aten::mul.out"])

    def test_dump_operator_from_ops_schema_yaml_with_op_syntax(self) -> None:
        """``op:`` entries are read as ``aten::<name>`` keys."""
        ops_yaml = self._write_temp_file("ops.yaml", _OP_SYNTAX_OPS_YAML)
        ops = gen_oplist._get_et_kernel_metadata_from_ops_yaml(ops_yaml)
        self.assertListEqual(sorted(ops.keys()), ["aten::add.out", "aten::mul.out"])

    def test_dump_operator_from_ops_schema_yaml_with_mix_syntax(self) -> None:
        """A file mixing ``op:`` and ``func:`` entries yields both ops."""
        mix_yaml = self._write_temp_file("mix.yaml", _MIXED_SYNTAX_OPS_YAML)
        ops = gen_oplist._get_et_kernel_metadata_from_ops_yaml(mix_yaml)
        self.assertListEqual(sorted(ops.keys()), ["aten::add.out", "aten::mul.out"])

    def test_get_kernel_metadata_from_ops_yaml(self) -> None:
        """Each op in the schema YAML maps to a single "default" entry."""
        metadata: Dict[str, List[str]] = (
            gen_oplist._get_et_kernel_metadata_from_ops_yaml(self.ops_schema_yaml)
        )

        self.assertEqual(len(metadata), 2)

        self.assertIn("aten::add.out", metadata)
        # The schema YAML declares no dtype/dim-order kernel metadata, so the
        # op is expected to carry exactly one "default" entry.
        self.assertEqual(len(metadata["aten::add.out"]), 1)
        self.assertEqual(
            metadata["aten::add.out"][0],
            "default",
        )

        self.assertIn("aten::mul.out", metadata)
        self.assertEqual(len(metadata["aten::mul.out"]), 1)
        self.assertEqual(
            metadata["aten::mul.out"][0],
            "default",
        )


if __name__ == "__main__":
    unittest.main()