# xref: /aosp_15_r20/external/pytorch/tools/test/test_selective_build.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import unittest
2
3from torchgen.model import Location, NativeFunction
4from torchgen.selective_build.operator import *  # noqa: F403
5from torchgen.selective_build.selector import (
6    combine_selective_builders,
7    SelectiveBuilder,
8)
9
10
11class TestSelectiveBuild(unittest.TestCase):
12    def test_selective_build_operator(self) -> None:
13        op = SelectiveBuildOperator(
14            "aten::add.int",
15            is_root_operator=True,
16            is_used_for_training=False,
17            include_all_overloads=False,
18            _debug_info=None,
19        )
20        self.assertTrue(op.is_root_operator)
21        self.assertFalse(op.is_used_for_training)
22        self.assertFalse(op.include_all_overloads)
23
24    def test_selector_factory(self) -> None:
25        yaml_config_v1 = """
26debug_info:
27  - model1@v100
28  - model2@v51
29operators:
30  aten::add:
31    is_used_for_training: No
32    is_root_operator: Yes
33    include_all_overloads: Yes
34  aten::add.int:
35    is_used_for_training: Yes
36    is_root_operator: No
37    include_all_overloads: No
38  aten::mul.int:
39    is_used_for_training: Yes
40    is_root_operator: No
41    include_all_overloads: No
42"""
43
44        yaml_config_v2 = """
45debug_info:
46  - model1@v100
47  - model2@v51
48operators:
49  aten::sub:
50    is_used_for_training: No
51    is_root_operator: Yes
52    include_all_overloads: No
53    debug_info:
54      - model1@v100
55  aten::sub.int:
56    is_used_for_training: Yes
57    is_root_operator: No
58    include_all_overloads: No
59"""
60
61        yaml_config_all = "include_all_operators: Yes"
62
63        yaml_config_invalid = "invalid:"
64
65        selector1 = SelectiveBuilder.from_yaml_str(yaml_config_v1)
66
67        self.assertTrue(selector1.is_operator_selected("aten::add"))
68        self.assertTrue(selector1.is_operator_selected("aten::add.int"))
69        # Overload name is not used for checking in v1.
70        self.assertTrue(selector1.is_operator_selected("aten::add.float"))
71
72        def gen():
73            return SelectiveBuilder.from_yaml_str(yaml_config_invalid)
74
75        self.assertRaises(Exception, gen)
76
77        selector_all = SelectiveBuilder.from_yaml_str(yaml_config_all)
78
79        self.assertTrue(selector_all.is_operator_selected("aten::add"))
80        self.assertTrue(selector_all.is_operator_selected("aten::sub"))
81        self.assertTrue(selector_all.is_operator_selected("aten::sub.int"))
82        self.assertTrue(selector_all.is_kernel_dtype_selected("add_kernel", "int32"))
83
84        selector2 = SelectiveBuilder.from_yaml_str(yaml_config_v2)
85
86        self.assertFalse(selector2.is_operator_selected("aten::add"))
87        self.assertTrue(selector2.is_operator_selected("aten::sub"))
88        self.assertTrue(selector2.is_operator_selected("aten::sub.int"))
89
90        selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
91            ["aten::add", "aten::add.int", "aten::mul.int"],
92            False,
93            False,
94        )
95        self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.float"))
96        self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add"))
97        self.assertTrue(selector_legacy_v1.is_operator_selected("aten::add.int"))
98        self.assertFalse(selector_legacy_v1.is_operator_selected("aten::sub"))
99
100        self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
101        self.assertFalse(
102            selector_legacy_v1.is_operator_selected_for_training("aten::add")
103        )
104
105        selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
106            ["aten::add", "aten::add.int", "aten::mul.int"],
107            True,
108            False,
109        )
110
111        self.assertTrue(selector_legacy_v1.is_root_operator("aten::add"))
112        self.assertFalse(
113            selector_legacy_v1.is_operator_selected_for_training("aten::add")
114        )
115        self.assertTrue(selector_legacy_v1.is_root_operator("aten::add.float"))
116        self.assertFalse(
117            selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
118        )
119
120        selector_legacy_v1 = SelectiveBuilder.from_legacy_op_registration_allow_list(
121            ["aten::add", "aten::add.int", "aten::mul.int"],
122            False,
123            True,
124        )
125
126        self.assertFalse(selector_legacy_v1.is_root_operator("aten::add"))
127        self.assertTrue(
128            selector_legacy_v1.is_operator_selected_for_training("aten::add")
129        )
130        self.assertFalse(selector_legacy_v1.is_root_operator("aten::add.float"))
131        self.assertTrue(
132            selector_legacy_v1.is_operator_selected_for_training("aten::add.float")
133        )
134
135    def test_operator_combine(self) -> None:
136        op1 = SelectiveBuildOperator(
137            "aten::add.int",
138            is_root_operator=True,
139            is_used_for_training=False,
140            include_all_overloads=False,
141            _debug_info=None,
142        )
143        op2 = SelectiveBuildOperator(
144            "aten::add.int",
145            is_root_operator=False,
146            is_used_for_training=False,
147            include_all_overloads=False,
148            _debug_info=None,
149        )
150        op3 = SelectiveBuildOperator(
151            "aten::add",
152            is_root_operator=True,
153            is_used_for_training=False,
154            include_all_overloads=False,
155            _debug_info=None,
156        )
157        op4 = SelectiveBuildOperator(
158            "aten::add.int",
159            is_root_operator=True,
160            is_used_for_training=True,
161            include_all_overloads=False,
162            _debug_info=None,
163        )
164
165        op5 = combine_operators(op1, op2)
166
167        self.assertTrue(op5.is_root_operator)
168        self.assertFalse(op5.is_used_for_training)
169
170        op6 = combine_operators(op1, op4)
171
172        self.assertTrue(op6.is_root_operator)
173        self.assertTrue(op6.is_used_for_training)
174
175        def gen_new_op():
176            return combine_operators(op1, op3)
177
178        self.assertRaises(Exception, gen_new_op)
179
180    def test_training_op_fetch(self) -> None:
181        yaml_config = """
182operators:
183  aten::add.int:
184    is_used_for_training: No
185    is_root_operator: Yes
186    include_all_overloads: No
187  aten::add:
188    is_used_for_training: Yes
189    is_root_operator: No
190    include_all_overloads: Yes
191"""
192
193        selector = SelectiveBuilder.from_yaml_str(yaml_config)
194        self.assertTrue(selector.is_operator_selected_for_training("aten::add.int"))
195        self.assertTrue(selector.is_operator_selected_for_training("aten::add"))
196
197    def test_kernel_dtypes(self) -> None:
198        yaml_config = """
199kernel_metadata:
200  add_kernel:
201    - int8
202    - int32
203  sub_kernel:
204    - int16
205    - int32
206  add/sub_kernel:
207    - float
208    - complex
209"""
210
211        selector = SelectiveBuilder.from_yaml_str(yaml_config)
212
213        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
214        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
215        self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
216        self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
217        self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))
218
219        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
220        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
221        self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
222        self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
223
224    def test_merge_kernel_dtypes(self) -> None:
225        yaml_config1 = """
226kernel_metadata:
227  add_kernel:
228    - int8
229  add/sub_kernel:
230    - float
231    - complex
232    - none
233  mul_kernel:
234    - int8
235"""
236
237        yaml_config2 = """
238kernel_metadata:
239  add_kernel:
240    - int32
241  sub_kernel:
242    - int16
243    - int32
244  add/sub_kernel:
245    - float
246    - complex
247"""
248
249        selector1 = SelectiveBuilder.from_yaml_str(yaml_config1)
250        selector2 = SelectiveBuilder.from_yaml_str(yaml_config2)
251
252        selector = combine_selective_builders(selector1, selector2)
253
254        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
255        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
256        self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "int16"))
257        self.assertFalse(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
258        self.assertFalse(selector.is_kernel_dtype_selected("add_kernel", "float"))
259
260        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "float"))
261        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "complex"))
262        self.assertTrue(selector.is_kernel_dtype_selected("add/sub_kernel", "none"))
263        self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int16"))
264        self.assertFalse(selector.is_kernel_dtype_selected("add/sub_kernel", "int32"))
265
266        self.assertTrue(selector.is_kernel_dtype_selected("mul_kernel", "int8"))
267        self.assertFalse(selector.is_kernel_dtype_selected("mul_kernel", "int32"))
268
269    def test_all_kernel_dtypes_selected(self) -> None:
270        yaml_config = """
271include_all_non_op_selectives: True
272"""
273
274        selector = SelectiveBuilder.from_yaml_str(yaml_config)
275
276        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int32"))
277        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int8"))
278        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "int16"))
279        self.assertTrue(selector.is_kernel_dtype_selected("add1_kernel", "int32"))
280        self.assertTrue(selector.is_kernel_dtype_selected("add_kernel", "float"))
281
282    def test_custom_namespace_selected_correctly(self) -> None:
283        yaml_config = """
284operators:
285  aten::add.int:
286    is_used_for_training: No
287    is_root_operator: Yes
288    include_all_overloads: No
289  custom::add:
290    is_used_for_training: Yes
291    is_root_operator: No
292    include_all_overloads: Yes
293"""
294        selector = SelectiveBuilder.from_yaml_str(yaml_config)
295        native_function, _ = NativeFunction.from_yaml(
296            {"func": "custom::add() -> Tensor"},
297            loc=Location(__file__, 1),
298            valid_tags=set(),
299        )
300        self.assertTrue(selector.is_native_function_selected(native_function))
301
302
303class TestExecuTorchSelectiveBuild(unittest.TestCase):
304    def test_et_kernel_selected(self) -> None:
305        yaml_config = """
306et_kernel_metadata:
307  aten::add.out:
308   - "v1/6;0,1|6;0,1|6;0,1|6;0,1"
309  aten::sub.out:
310   - "v1/6;0,1|6;0,1|6;0,1|6;0,1"
311"""
312        selector = SelectiveBuilder.from_yaml_str(yaml_config)
313        self.assertListEqual(
314            ["v1/6;0,1|6;0,1|6;0,1|6;0,1"],
315            selector.et_get_selected_kernels(
316                "aten::add.out",
317                [
318                    "v1/6;0,1|6;0,1|6;0,1|6;0,1",
319                    "v1/3;0,1|3;0,1|3;0,1|3;0,1",
320                    "v1/6;1,0|6;0,1|6;0,1|6;0,1",
321                ],
322            ),
323        )
324        self.assertListEqual(
325            ["v1/6;0,1|6;0,1|6;0,1|6;0,1"],
326            selector.et_get_selected_kernels(
327                "aten::sub.out", ["v1/6;0,1|6;0,1|6;0,1|6;0,1"]
328            ),
329        )
330        self.assertListEqual(
331            [],
332            selector.et_get_selected_kernels(
333                "aten::mul.out", ["v1/6;0,1|6;0,1|6;0,1|6;0,1"]
334            ),
335        )
336        # We don't use version for now.
337        self.assertListEqual(
338            ["v2/6;0,1|6;0,1|6;0,1|6;0,1"],
339            selector.et_get_selected_kernels(
340                "aten::add.out", ["v2/6;0,1|6;0,1|6;0,1|6;0,1"]
341            ),
342        )
343