import unittest from torchgen.model import Location, NativeFunction from torchgen.selective_build.operator import * # noqa: F403 from torchgen.selective_build.selector import ( combine_selective_builders, SelectiveBuilder, ) class TestSelectiveBuild(unittest.TestCase): def test_selective_build_operator(self) -> None: op = SelectiveBuildOperator( "aten::add.int", is_root_operator=True, is_used_for_training=False, include_all_overloads=False, _debug_info=None, ) self.assertTrue(op.is_root_operator) self.assertFalse(op.is_used_for_training) self.assertFalse(op.include_all_overloads) def test_selector_factory(self) -> None: yaml_config_v1 = """ debug_info: - model1@v100 - model2@v51 operators: aten::add: is_used_for_training: No is_root_operator: Yes include_all_overloads: Yes aten::add.int: is_used_for_training: Yes is_root_operator: No include_all_overloads: No aten::mul.int: is_used_for_training: Yes is_root_operator: No include_all_overloads: No """ yaml_config_v2 = """ debug_info: - model1@v100 - model2@v51 operators: aten::sub: is_used_for_training: No is_root_operator: Yes include_all_overloads: No debug_info: - model1@v100 aten::sub.int: is_used_for_training: Yes is_root_operator: No include_all_overloads: No """ yaml_config_all = "include_all_operators: Yes" yaml_config_invalid = "invalid:" selector1 = SelectiveBuilder.from_yaml_str(yaml_config_v1) self.assertTrue(selector1.is_operator_selected("aten::add")) self.assertTrue(selector1.is_operator_selected("aten::add.int")) # Overload name is not used for checking in v1. self.assertTrue(selector1.is_operator_selected("aten::add.float")) def gen(): return SelectiveBuilder.from_yaml_str(yaml_config_invalid) self.assertRaises(Exception, gen) selector_all = SelectiveBuilder.from_yaml_str(yaml_config_all) self.assertTrue(selector_all.is_operator_selected("aten::add")) self.assertTrue(selector_all.is_operator_selected("aten::sub")) self.assertTrue(selector_all.is_operator_selected("aten::sub.int")) self.assertTrue(selector_all.is_kernel_dtype_selected("add_kernel", "int32")) selector2 = SelectiveBuilder.from_yaml_str(yaml_config_v2) self.assertFalse(selector2.is_operator_selected("aten::add")) self.assertTrue(selector2.is_operator_selected("aten::sub")) self.assertTrue(selector2.is_operator_selected("aten::sub.int")) selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list( ["aten::add", "aten::add.int", "aten::mul.int"], False, False, ) self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.float")) self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add")) self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.int")) self.assertFalse(selector_legacy_v1.is_operator_selected("aten::sub")) self.assertFalse(selector_legacy_v1.is_root_operator("aten::add")) self.assertFalse( selector_legacy_v1.is_operator_selected_for_training("aten::add") ) selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list( ["aten::add", "aten::add.int", "aten::mul.int"], True, False, ) self.assertTrue(selector_legacy_v1.is_root_operator("aten::add")) self.assertFalse( selector_legacy_v1.is_operator_selected_for_training("aten::add") ) self.assertTrue(selector_legacy_v1.is_root_operator("aten::add.float")) self.assertFalse( selector_legacy_v1.is_operator_selected_for_training("aten::add.float") ) selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list( ["aten::add", "aten::add.int", "aten::mul.int"], False, True, ) self.assertFalse(selector_legacy_v1.is_root_operator("aten::add")) self.assertTrue( selector_legacy_v1.is_operator_selected_for_training("aten::add") ) self.assertFalse(selector_legacy_v1.is_root_operator("aten::add.float")) self.assertTrue( selector_legacy_v1.is_operator_selected_for_training("aten::add.float") ) def test_operator_combine(self) -> None: op1 = SelectiveBuildOperator( "aten::add.int", is_root_operator=True, is_used_for_training=False, include_all_overloads=False, _debug_info=None, ) op2 = SelectiveBuildOperator( "aten::add.int", is_root_operator=False, is_used_for_training=False, include_all_overloads=False, _debug_info=None, ) op3 = SelectiveBuildOperator( "aten::add", is_root_operator=True, is_used_for_training=False, include_all_overloads=False, _debug_info=None, ) op4 = SelectiveBuildOperator( "aten::add.int", is_root_operator=True, is_used_for_training=True, include_all_overloads=False, _debug_info=None, ) op5 = combine_operators(op1, op2) self.assertTrue(op5.is_root_operator) self.assertFalse(op5.is_used_for_training) op6 = combine_operators(op1, op4) self.assertTrue(op6.is_root_operator) self.assertTrue(op6.is_used_for_training) def gen_new_op(): return combine_operators(op1, op3) self.assertRaises(Exception, gen_new_op) def test_training_op_fetch(self) -> None: yaml_config = """ operators: aten::add.int: is_used_for_training: No is_root_operator: Yes include_all_overloads: No aten::add: is_used_for_training: Yes is_root_operator: No include_all_overloads: Yes """ selector = SelectiveBuilder.from_yaml_str(yaml_config) self.assertTrue(selector.is_operator_selected_for_training("aten::add.int")) self.assertTrue(selector.is_operator_selected_for_training("aten::add")) def test_kernel_dtypes(self) -> None: yaml_config = """ kernel_metadata: add_kernel: - int8 - int32 sub_kernel: - int16 - int32 add/sub_kernel: - float - complex """ selector = SelectiveBuilder.from_yaml_str(yaml_config) self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32")) self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8")) self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16")) self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32")) self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float")) self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float")) self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex")) self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16")) self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32")) def test_merge_kernel_dtypes(self) -> None: yaml_config1 = """ kernel_metadata: add_kernel: - int8 add/sub_kernel: - float - complex - none mul_kernel: - int8 """ yaml_config2 = """ kernel_metadata: add_kernel: - int32 sub_kernel: - int16 - int32 add/sub_kernel: - float - complex """ selector1 = SelectiveBuilder.from_yaml_str(yaml_config1) selector2 = SelectiveBuilder.from_yaml_str(yaml_config2) selector = combine_selective_builders(selector1, selector2) self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32")) self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8")) self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16")) self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32")) self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float")) self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float")) self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex")) self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "none")) self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16")) self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32")) self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8")) self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32")) def test_all_kernel_dtypes_selected(self) -> None: yaml_config = """ include_all_non_op_selectives: True """ selector = SelectiveBuilder.from_yaml_str(yaml_config) self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32")) self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8")) self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int16")) self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32")) self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float")) def test_custom_namespace_selected_correctly(self) -> None: yaml_config = """ operators: aten::add.int: is_used_for_training: No is_root_operator: Yes include_all_overloads: No custom::add: is_used_for_training: Yes is_root_operator: No include_all_overloads: Yes """ selector = SelectiveBuilder.from_yaml_str(yaml_config) native_function, _ = NativeFunction.from_yaml( {"func": "custom::add() -> Tensor"}, loc=Location(__file__, 1), valid_tags=set(), ) self.assertTrue(selector.is_native_function_selected(native_function)) class TestExecuTorchSelectiveBuild(unittest.TestCase): def test_et_kernel_selected(self) -> None: yaml_config = """ et_kernel_metadata: aten::add.out: - "v1/6;0,1|6;0,1|6;0,1|6;0,1" aten::sub.out: - "v1/6;0,1|6;0,1|6;0,1|6;0,1" """ selector = SelectiveBuilder.from_yaml_str(yaml_config) self.assertListEqual( ["v1/6;0,1|6;0,1|6;0,1|6;0,1"], selector.et_get_selected_kernels( "aten::add.out", [ "v1/6;0,1|6;0,1|6;0,1|6;0,1", "v1/3;0,1|3;0,1|3;0,1|3;0,1", "v1/6;1,0|6;0,1|6;0,1|6;0,1", ], ), ) self.assertListEqual( ["v1/6;0,1|6;0,1|6;0,1|6;0,1"], selector.et_get_selected_kernels( "aten::sub.out", ["v1/6;0,1|6;0,1|6;0,1|6;0,1"] ), ) self.assertListEqual( [], selector.et_get_selected_kernels( "aten::mul.out", ["v1/6;0,1|6;0,1|6;0,1|6;0,1"] ), ) # We don't use version for now. self.assertListEqual( ["v2/6;0,1|6;0,1|6;0,1|6;0,1"], selector.et_get_selected_kernels( "aten::add.out", ["v2/6;0,1|6;0,1|6;0,1|6;0,1"] ), )