import unittest

from torchgen.model import Location, NativeFunction
from torchgen.selective_build.operator import *  # noqa: F403
from torchgen.selective_build.selector import (
    combine_selective_builders,
    SelectiveBuilder,
)


class TestSelectiveBuild(unittest.TestCase):
    """Tests for ``SelectiveBuilder`` and ``SelectiveBuildOperator``."""

    def test_selective_build_operator(self) -> None:
        """Flags passed to the constructor are exposed as attributes."""
        op = SelectiveBuildOperator(
            "aten::add.int",
            is_root_operator=True,
            is_used_for_training=False,
            include_all_overloads=False,
            _debug_info=None,
        )
        self.assertTrue(op.is_root_operator)
        self.assertFalse(op.is_used_for_training)
        self.assertFalse(op.include_all_overloads)

    def test_selector_factory(self) -> None:
        """Selectors built from YAML strings and from legacy allow-lists."""
        yaml_config_v1 = """
debug_info:
  - model1@v100
  - model2@v51
operators:
  aten::add:
    is_used_for_training: No
    is_root_operator: Yes
    include_all_overloads: Yes
  aten::add.int:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: No
  aten::mul.int:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: No
"""

        yaml_config_v2 = """
debug_info:
  - model1@v100
  - model2@v51
operators:
  aten::sub:
    is_used_for_training: No
    is_root_operator: Yes
    include_all_overloads: No
    debug_info:
      - model1@v100
  aten::sub.int:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: No
"""

        yaml_config_all = "include_all_operators: Yes"

        yaml_config_invalid = "invalid:"

        selector1 = SelectiveBuilder.from_yaml_str(yaml_config_v1)

        self.assertTrue(selector1.is_operator_selected("aten::add"))
        self.assertTrue(selector1.is_operator_selected("aten::add.int"))
        # Overload name is not used for checking in v1: aten::add sets
        # include_all_overloads: Yes, so an unlisted overload is selected too.
        self.assertTrue(selector1.is_operator_selected("aten::add.float"))

        # A config whose only top-level key is "invalid" must be rejected.
        with self.assertRaises(Exception):
            SelectiveBuilder.from_yaml_str(yaml_config_invalid)

        selector_all = SelectiveBuilder.from_yaml_str(yaml_config_all)

        self.assertTrue(selector_all.is_operator_selected("aten::add"))
        self.assertTrue(selector_all.is_operator_selected("aten::sub"))
        self.assertTrue(selector_all.is_operator_selected("aten::sub.int"))
        self.assertTrue(selector_all.is_kernel_dtype_selected("add_kernel", "int32"))

        selector2 = SelectiveBuilder.from_yaml_str(yaml_config_v2)

        self.assertFalse(selector2.is_operator_selected("aten::add"))
        self.assertTrue(selector2.is_operator_selected("aten::sub"))
        self.assertTrue(selector2.is_operator_selected("aten::sub.int"))

        # Legacy allow-list: neither root nor training.
        selector_legacy_plain = SelectiveBuilder.from_legacy_op_registration_allow_list(
            ["aten::add", "aten::add.int", "aten::mul.int"],
            False,  # is_root_operator
            False,  # is_used_for_training
        )
        self.assertTrue(selector_legacy_plain.is_operator_selected("aten::add.float"))
        self.assertTrue(selector_legacy_plain.is_operator_selected("aten::add"))
        self.assertTrue(selector_legacy_plain.is_operator_selected("aten::add.int"))
        self.assertFalse(selector_legacy_plain.is_operator_selected("aten::sub"))

        self.assertFalse(selector_legacy_plain.is_root_operator("aten::add"))
        self.assertFalse(
            selector_legacy_plain.is_operator_selected_for_training("aten::add")
        )

        # Legacy allow-list marked as root operators only.
        selector_legacy_root = SelectiveBuilder.from_legacy_op_registration_allow_list(
            ["aten::add", "aten::add.int", "aten::mul.int"],
            True,  # is_root_operator
            False,  # is_used_for_training
        )

        self.assertTrue(selector_legacy_root.is_root_operator("aten::add"))
        self.assertFalse(
            selector_legacy_root.is_operator_selected_for_training("aten::add")
        )
        self.assertTrue(selector_legacy_root.is_root_operator("aten::add.float"))
        self.assertFalse(
            selector_legacy_root.is_operator_selected_for_training("aten::add.float")
        )

        # Legacy allow-list marked as used for training only.
        selector_legacy_training = (
            SelectiveBuilder.from_legacy_op_registration_allow_list(
                ["aten::add", "aten::add.int", "aten::mul.int"],
                False,  # is_root_operator
                True,  # is_used_for_training
            )
        )

        self.assertFalse(selector_legacy_training.is_root_operator("aten::add"))
        self.assertTrue(
            selector_legacy_training.is_operator_selected_for_training("aten::add")
        )
        self.assertFalse(selector_legacy_training.is_root_operator("aten::add.float"))
        self.assertTrue(
            selector_legacy_training.is_operator_selected_for_training(
                "aten::add.float"
            )
        )

    def test_operator_combine(self) -> None:
        """``combine_operators`` ORs the flags and rejects mismatched names."""
        op1 = SelectiveBuildOperator(
            "aten::add.int",
            is_root_operator=True,
            is_used_for_training=False,
            include_all_overloads=False,
            _debug_info=None,
        )
        op2 = SelectiveBuildOperator(
            "aten::add.int",
            is_root_operator=False,
            is_used_for_training=False,
            include_all_overloads=False,
            _debug_info=None,
        )
        op3 = SelectiveBuildOperator(
            "aten::add",
            is_root_operator=True,
            is_used_for_training=False,
            include_all_overloads=False,
            _debug_info=None,
        )
        op4 = SelectiveBuildOperator(
            "aten::add.int",
            is_root_operator=True,
            is_used_for_training=True,
            include_all_overloads=False,
            _debug_info=None,
        )

        op5 = combine_operators(op1, op2)

        self.assertTrue(op5.is_root_operator)
        self.assertFalse(op5.is_used_for_training)

        op6 = combine_operators(op1, op4)

        self.assertTrue(op6.is_root_operator)
        self.assertTrue(op6.is_used_for_training)

        # op1 ("aten::add.int") and op3 ("aten::add") have different names.
        with self.assertRaises(Exception):
            combine_operators(op1, op3)

    def test_training_op_fetch(self) -> None:
        """Both entries are reported as selected for training."""
        yaml_config = """
operators:
  aten::add.int:
    is_used_for_training: No
    is_root_operator: Yes
    include_all_overloads: No
  aten::add:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: Yes
"""

        selector = SelectiveBuilder.from_yaml_str(yaml_config)
        self.assertTrue(selector.is_operator_selected_for_training("aten::add.int"))
        self.assertTrue(selector.is_operator_selected_for_training("aten::add"))

    def test_kernel_dtypes(self) -> None:
        """Only the dtypes listed for a kernel tag are selected."""
        yaml_config = """
kernel_metadata:
  add_kernel:
    - int8
    - int32
  sub_kernel:
    - int16
    - int32
  add/sub_kernel:
    - float
    - complex
"""

        selector = SelectiveBuilder.from_yaml_str(yaml_config)

        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
        self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
        self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
        self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))

        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
        self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
        self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))

    def test_merge_kernel_dtypes(self) -> None:
        """Combining two selectors yields the union of their kernel dtypes."""
        yaml_config1 = """
kernel_metadata:
  add_kernel:
    - int8
  add/sub_kernel:
    - float
    - complex
    - none
  mul_kernel:
    - int8
"""

        yaml_config2 = """
kernel_metadata:
  add_kernel:
    - int32
  sub_kernel:
    - int16
    - int32
  add/sub_kernel:
    - float
    - complex
"""

        selector1 = SelectiveBuilder.from_yaml_str(yaml_config1)
        selector2 = SelectiveBuilder.from_yaml_str(yaml_config2)

        selector = combine_selective_builders(selector1, selector2)

        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
        self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
        self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
        self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))

        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "none"))
        self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
        self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))

        self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8"))
        self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32"))

    def test_all_kernel_dtypes_selected(self) -> None:
        """``include_all_non_op_selectives`` selects every kernel/dtype pair."""
        yaml_config = """
include_all_non_op_selectives: True
"""

        selector = SelectiveBuilder.from_yaml_str(yaml_config)

        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int16"))
        self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float"))

    def test_custom_namespace_selected_correctly(self) -> None:
        """A function in a non-aten namespace listed in the config is selected."""
        yaml_config = """
operators:
  aten::add.int:
    is_used_for_training: No
    is_root_operator: Yes
    include_all_overloads: No
  custom::add:
    is_used_for_training: Yes
    is_root_operator: No
    include_all_overloads: Yes
"""
        selector = SelectiveBuilder.from_yaml_str(yaml_config)
        native_function, _ = NativeFunction.from_yaml(
            {"func": "custom::add() -> Tensor"},
            loc=Location(__file__, 1),
            valid_tags=set(),
        )
        self.assertTrue(selector.is_native_function_selected(native_function))


class TestExecuTorchSelectiveBuild(unittest.TestCase):
    """Tests for ExecuTorch kernel-key selection (``et_kernel_metadata``)."""

    def test_et_kernel_selected(self) -> None:
        """Only kernel keys listed for an operator are returned."""
        yaml_config = """
et_kernel_metadata:
  aten::add.out:
    - "v1/6;0,1|6;0,1|6;0,1|6;0,1"
  aten::sub.out:
    - "v1/6;0,1|6;0,1|6;0,1|6;0,1"
"""
        selector = SelectiveBuilder.from_yaml_str(yaml_config)
        self.assertListEqual(
            ["v1/6;0,1|6;0,1|6;0,1|6;0,1"],
            selector.et_get_selected_kernels(
                "aten::add.out",
                [
                    "v1/6;0,1|6;0,1|6;0,1|6;0,1",
                    "v1/3;0,1|3;0,1|3;0,1|3;0,1",
                    "v1/6;1,0|6;0,1|6;0,1|6;0,1",
                ],
            ),
        )
        self.assertListEqual(
            ["v1/6;0,1|6;0,1|6;0,1|6;0,1"],
            selector.et_get_selected_kernels(
                "aten::sub.out", ["v1/6;0,1|6;0,1|6;0,1|6;0,1"]
            ),
        )
        # Operators absent from et_kernel_metadata get no kernels.
        self.assertListEqual(
            [],
            selector.et_get_selected_kernels(
                "aten::mul.out", ["v1/6;0,1|6;0,1|6;0,1|6;0,1"]
            ),
        )
        # We don't use version for now.
        self.assertListEqual(
            ["v2/6;0,1|6;0,1|6;0,1|6;0,1"],
            selector.et_get_selected_kernels(
                "aten::add.out", ["v2/6;0,1|6;0,1|6;0,1|6;0,1"]
            ),
        )