xref: /aosp_15_r20/external/pytorch/tools/test/test_codegen.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport dataclasses
4*da0073e9SAndroid Build Coastguard Workerimport typing
5*da0073e9SAndroid Build Coastguard Workerimport unittest
6*da0073e9SAndroid Build Coastguard Workerfrom collections import defaultdict
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport yaml
9*da0073e9SAndroid Build Coastguard Workerfrom tools.autograd import gen_autograd_functions, load_derivatives
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Workerfrom torchgen import dest
12*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import CppSignatureGroup, DispatcherSignature
13*da0073e9SAndroid Build Coastguard Workerfrom torchgen.context import native_function_manager
14*da0073e9SAndroid Build Coastguard Workerfrom torchgen.gen import (
15*da0073e9SAndroid Build Coastguard Worker    get_native_function_declarations,
16*da0073e9SAndroid Build Coastguard Worker    get_native_function_schema_registrations,
17*da0073e9SAndroid Build Coastguard Worker    LineLoader,
18*da0073e9SAndroid Build Coastguard Worker    static_dispatch,
19*da0073e9SAndroid Build Coastguard Worker)
20*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import (
21*da0073e9SAndroid Build Coastguard Worker    BackendIndex,
22*da0073e9SAndroid Build Coastguard Worker    BackendMetadata,
23*da0073e9SAndroid Build Coastguard Worker    DispatchKey,
24*da0073e9SAndroid Build Coastguard Worker    FunctionSchema,
25*da0073e9SAndroid Build Coastguard Worker    Location,
26*da0073e9SAndroid Build Coastguard Worker    NativeFunction,
27*da0073e9SAndroid Build Coastguard Worker    OperatorName,
28*da0073e9SAndroid Build Coastguard Worker)
29*da0073e9SAndroid Build Coastguard Workerfrom torchgen.native_function_generation import add_generated_native_functions
30*da0073e9SAndroid Build Coastguard Workerfrom torchgen.selective_build.selector import SelectiveBuilder
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker
class TestCreateDerivative(unittest.TestCase):
    """Tests for how ``load_derivatives`` resolves gradient references.

    A derivative formula may refer to an incoming gradient either by output
    name (``grad_x``) or by position (``grads[0]``). These tests pin down which
    names get recorded, which names are offered for non-differentiable
    outputs, and that mixing both styles in one definition is rejected.
    """

    @staticmethod
    def _native_function_for(spec: str) -> tuple[FunctionSchema, NativeFunction]:
        """Parse *spec* and return it with a copy of DEFAULT_NATIVE_FUNCTION using it."""
        parsed = FunctionSchema.parse(spec)
        return parsed, dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=parsed)

    @staticmethod
    def _differentiability_info(
        spec: str,
        parsed: FunctionSchema,
        fn: NativeFunction,
        dispatch: dict[str, dict[str, str]],
    ) -> typing.Any:
        """Run ``create_differentiability_info`` for a single-overload operator.

        *dispatch* maps a dispatch-key name (e.g. ``"Default"``) to the
        per-input derivative formulas for that key.
        """
        return load_derivatives.create_differentiability_info(
            defn_dict={"name": spec, "dispatch": dispatch},
            functions_by_signature={parsed.signature(): [fn]},
            functions_by_schema={spec: fn},
            op_counter=typing.Counter[str](),
            used_dispatch_keys=set(),
        )

    def test_named_grads(self) -> None:
        _, fn = self._native_function_for(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        )

        derivative = load_derivatives.create_derivative(
            fn,
            formula="func_backward(grad_x, grad_y)",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )
        # Both outputs are referenced by name, so both names are recorded.
        self.assertSetEqual(derivative.named_gradients, {"grad_x", "grad_y"})

    def test_non_differentiable_output(self) -> None:
        spec = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        parsed, fn = self._native_function_for(spec)

        _, info = self._differentiability_info(
            spec, parsed, fn, {"Default": {"a": "grads[0]", "b": "grads[2]"}}
        )

        # y is a bool and therefore not differentiable, so no grad_y is offered.
        self.assertSequenceEqual(
            info["Default"].available_named_gradients,
            ["grad_x", "grad_z"],
        )

    def test_indexed_grads(self) -> None:
        _, fn = self._native_function_for(
            "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        )

        derivative = load_derivatives.create_derivative(
            fn,
            formula="func_backward(grads[0], grads[1])",
            var_names=(),
            available_named_gradients=["grad_x", "grad_y"],
        )
        # Positional references are not recorded as named gradients.
        self.assertSetEqual(derivative.named_gradients, set())

    def test_named_grads_and_indexed_grads(self) -> None:
        spec = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y)"
        parsed, fn = self._native_function_for(spec)
        # One formula uses a gradient by name, the other by index: not allowed.
        mixed_formulas = {"Default": {"a": "grad_x", "b": "grads[1]"}}

        with self.assertRaisesRegex(
            RuntimeError, 'illegally mixes use of "grad_RETURN_NAME"'
        ):
            self._differentiability_info(spec, parsed, fn, mixed_formulas)
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker
class TestGenAutogradFunctions(unittest.TestCase):
    """Tests for autograd code generated from derivative definitions.

    Covers the ``grads[i]`` slot that ``gen_autograd_functions.process_function``
    binds to each named gradient, and the rejection of unknown dispatch keys by
    ``load_derivatives.create_differentiability_info``.
    """

    def _assert_grad_z_reads_grads_1(self, definition: str) -> None:
        """Assert that the generated *definition* binds ``grad_z`` to ``grads[1]``.

        In the schemas this helper is used with, output 1 (``y``) is not
        differentiable, so ``z``'s gradient must come from slot 1 rather than
        from its raw output position 2.
        """
        # Use unittest assertions rather than bare ``assert``: bare asserts are
        # removed under ``python -O`` (turning the test into a no-op) and give
        # no context on failure.
        self.assertNotIn("grad_z = grads[2]", definition)
        self.assertIn("grad_z = grads[1]", definition)

    def test_non_differentiable_output_invalid_type(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        _, differentiability_info = load_derivatives.create_differentiability_info(
            defn_dict={
                "name": specification,
                "dispatch": {
                    "Default": {
                        "a": "grad_x",
                        "b": "grad_z",
                    }
                },
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
            used_dispatch_keys=set(),
        )
        definition = gen_autograd_functions.process_function(
            differentiability_info["Default"],
            gen_autograd_functions.FUNCTION_DEFINITION,
        )
        # y is a bool (not differentiable), so grad_z maps to grads[1].
        self._assert_grad_z_reads_grads_1(definition)

    def test_non_differentiable_output_output_differentiability(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, Tensor y, Tensor z)"
        schema = FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        _, differentiability_info = load_derivatives.create_differentiability_info(
            defn_dict={
                "name": specification,
                "dispatch": {
                    "Default": {
                        "a": "grad_x",
                        "b": "grad_z",
                    },
                    "AutogradNestedTensor": {
                        "a": "grad_z",
                        "b": "grad_x",
                    },
                },
                # y is a Tensor here but is explicitly marked non-differentiable.
                "output_differentiability": [True, False, True],
            },
            functions_by_signature={schema.signature(): [native_function]},
            functions_by_schema={specification: native_function},
            op_counter=typing.Counter[str](),
            used_dispatch_keys=set(),
        )

        # Both the Default and the AutogradNestedTensor formulas must skip y.
        for dispatch_key in ("Default", "AutogradNestedTensor"):
            with self.subTest(dispatch_key=dispatch_key):
                definition = gen_autograd_functions.process_function(
                    differentiability_info[dispatch_key],
                    gen_autograd_functions.FUNCTION_DEFINITION,
                )
                self._assert_grad_z_reads_grads_1(definition)

    def test_register_bogus_dispatch_key(self) -> None:
        specification = "func(Tensor a, Tensor b) -> (Tensor x, bool y, Tensor z)"
        schema = FunctionSchema.parse(specification)
        native_function = dataclasses.replace(DEFAULT_NATIVE_FUNCTION, func=schema)

        # "AutogradRandomTensor" is not a recognised dispatch key.
        with self.assertRaisesRegex(
            RuntimeError,
            "Invalid dispatch key AutogradRandomTensor in derivatives.yaml for",
        ):
            load_derivatives.create_differentiability_info(
                defn_dict={
                    "name": specification,
                    "dispatch": {
                        "Default": {
                            "a": "grad_x",
                            "b": "grad_z",
                        },
                        "AutogradRandomTensor": {
                            "a": "grad_x",
                            "b": "grad_z",
                        },
                    },
                },
                functions_by_signature={schema.signature(): [native_function]},
                functions_by_schema={specification: native_function},
                op_counter=typing.Counter[str](),
                used_dispatch_keys=set(),
            )
211*da0073e9SAndroid Build Coastguard Worker
212*da0073e9SAndroid Build Coastguard Worker
class TestGenSchemaRegistration(unittest.TestCase):
    """Tests for ``get_native_function_schema_registrations``.

    The function returns a pair: the ``m.def(...)`` lines for the default
    namespace, and one string holding a ``TORCH_LIBRARY`` (or
    ``TORCH_LIBRARY_FRAGMENT``) block per custom namespace.
    """

    # The single default-namespace registration expected for
    # DEFAULT_NATIVE_FUNCTION.
    _DEFAULT_NS_DEF = 'm.def("func() -> bool", {});\n'

    @staticmethod
    def _parse_native_function(func: str) -> NativeFunction:
        """Build a NativeFunction from a bare ``func:`` schema string."""
        native_function, _ = NativeFunction.from_yaml(
            {"func": func},
            loc=Location(__file__, 1),
            valid_tags=set(),
        )
        return native_function

    @staticmethod
    def _library_block(macro: str, namespace: str) -> str:
        """Return the expected registration block for one custom namespace.

        Each block registers the ``func() -> bool`` schema and begins with a
        newline, so blocks for several namespaces concatenate directly.
        """
        return (
            f"\n{macro}({namespace}, m) {{\n"
            '  m.def("func() -> bool", {});\n'
            "\n"
            "};"
        )

    def setUp(self) -> None:
        self.selector = SelectiveBuilder.get_nop_selector()
        self.custom_native_function = self._parse_native_function(
            "custom::func() -> bool"
        )
        self.fragment_custom_native_function = self._parse_native_function(
            "quantized_decomposed::func() -> bool"
        )

    def test_default_namespace_schema_registration_code_valid(self) -> None:
        default_ns, _ = get_native_function_schema_registrations(
            native_functions=[DEFAULT_NATIVE_FUNCTION],
            schema_selector=self.selector,
        )
        self.assertEqual(default_ns, [self._DEFAULT_NS_DEF])

    def test_custom_namespace_schema_registration_code_valid(self) -> None:
        _, custom_ns = get_native_function_schema_registrations(
            native_functions=[self.custom_native_function],
            schema_selector=self.selector,
        )
        self.assertEqual(custom_ns, self._library_block("TORCH_LIBRARY", "custom"))

    def test_fragment_custom_namespace_schema_registration_code_valid(self) -> None:
        """An existing namespace is extended with TORCH_LIBRARY_FRAGMENT.

        For example, the quantized namespace is already defined in
        native/quantized/library.cpp, so generated code must add to it rather
        than define it a second time.
        """
        _, custom_ns = get_native_function_schema_registrations(
            native_functions=[self.fragment_custom_native_function],
            schema_selector=self.selector,
        )
        self.assertEqual(
            custom_ns,
            self._library_block("TORCH_LIBRARY_FRAGMENT", "quantized_decomposed"),
        )

    def test_mixed_namespace_schema_registration_code_valid(self) -> None:
        default_ns, custom_ns = get_native_function_schema_registrations(
            native_functions=[DEFAULT_NATIVE_FUNCTION, self.custom_native_function],
            schema_selector=self.selector,
        )
        self.assertEqual(default_ns, [self._DEFAULT_NS_DEF])
        self.assertEqual(custom_ns, self._library_block("TORCH_LIBRARY", "custom"))

    def test_3_namespaces_schema_registration_code_valid(self) -> None:
        custom2_native_function = self._parse_native_function(
            "custom2::func() -> bool"
        )
        default_ns, custom_ns = get_native_function_schema_registrations(
            native_functions=[
                DEFAULT_NATIVE_FUNCTION,
                self.custom_native_function,
                custom2_native_function,
            ],
            schema_selector=self.selector,
        )
        self.assertEqual(default_ns, [self._DEFAULT_NS_DEF])
        # One block per custom namespace, in input order.
        expected = self._library_block("TORCH_LIBRARY", "custom") + self._library_block(
            "TORCH_LIBRARY", "custom2"
        )
        self.assertEqual(custom_ns, expected)
317*da0073e9SAndroid Build Coastguard Worker
318*da0073e9SAndroid Build Coastguard Worker
class TestGenNativeFunctionDeclaration(unittest.TestCase):
    """Tests for ``get_native_function_declarations``.

    ``op_1`` has only a CPU kernel. ``op_2`` also has a QuantizedCPU kernel
    that lives in the ``custom`` namespace, so declaring both ops together
    mixes kernel namespaces.
    """

    def setUp(self) -> None:
        self.op_1_native_function, op_1_metadata = NativeFunction.from_yaml(
            {"func": "op_1() -> bool", "dispatch": {"CPU": "kernel_1"}},
            loc=Location(__file__, 1),
            valid_tags=set(),
        )
        self.op_2_native_function, op_2_metadata = NativeFunction.from_yaml(
            {
                "func": "op_2() -> bool",
                "dispatch": {"CPU": "kernel_2", "QuantizedCPU": "custom::kernel_3"},
            },
            loc=Location(__file__, 1),
            valid_tags=set(),
        )

        # Start each dispatch key with an empty index, then fold in both ops.
        index_by_key: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = {
            DispatchKey.CPU: {},
            DispatchKey.QuantizedCPU: {},
        }
        for metadata in (op_1_metadata, op_2_metadata):
            BackendIndex.grow_index(index_by_key, metadata)

        self.backend_indices = {
            key: BackendIndex(
                dispatch_key=key,
                use_out_as_primary=True,
                external=False,
                device_guard=False,
                index=index,
            )
            for key, index in index_by_key.items()
        }

    def test_native_function_declaration_1_op_2_ns_error(self) -> None:
        # op_2's QuantizedCPU kernel is in the custom namespace; declaring it
        # alongside op_1 is expected to trip an assertion.
        with self.assertRaises(AssertionError):
            get_native_function_declarations(
                grouped_native_functions=[
                    self.op_1_native_function,
                    self.op_2_native_function,
                ],
                backend_indices=self.backend_indices,
                native_function_decl_gen=dest.compute_native_function_declaration,
            )

    def test_native_function_declaration_1_op_1_ns_valid(self) -> None:
        self.assertIsInstance(self.op_1_native_function, NativeFunction)
        declaration_lines = get_native_function_declarations(
            grouped_native_functions=[
                self.op_1_native_function,
            ],
            backend_indices=self.backend_indices,
            native_function_decl_gen=dest.compute_native_function_declaration,
        )
        # Spelled out piecewise so the leading newline and the trailing
        # eight-space line are explicit.
        expected = (
            "\n"
            "namespace at {\n"
            "namespace native {\n"
            "TORCH_API bool kernel_1();\n"
            "} // namespace native\n"
            "} // namespace at\n"
            "        "
        )
        self.assertEqual("\n".join(declaration_lines), expected)
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker
# Tests for torchgen.native_function_generation (autogenerated out= overloads).
# NOTE(review): "Generatrion" is a typo; left as-is so existing test selectors
# keep working.
class TestNativeFunctionGeneratrion(unittest.TestCase):
    """Check that ``add_generated_native_functions`` derives ``.out`` overloads.

    Each test starts from a functional operator whose entry requests
    ``autogen: <op>.out`` and verifies the generated schema string and the
    kernel name recorded for it under ``DispatchKey.CompositeExplicitAutograd``.
    """

    def setUp(self) -> None:
        self.native_functions: list[NativeFunction] = []
        self.backend_indices: dict[
            DispatchKey, dict[OperatorName, BackendMetadata]
        ] = defaultdict(dict)

        # Single-return functional op, parsed from YAML text with LineLoader.
        yaml_entry = """
- func: op(Tensor self) -> Tensor
  dispatch:
    CompositeExplicitAutograd: op
  autogen: op.out
        """
        parsed_entries = yaml.load(yaml_entry, Loader=LineLoader)
        self.one_return_func, one_return_metadata = NativeFunction.from_yaml(
            parsed_entries[0], loc=Location(__file__, 1), valid_tags=set()
        )
        BackendIndex.grow_index(self.backend_indices, one_return_metadata)

        # Two-return functional op, given directly as a dict and dispatched to
        # CPU only.
        two_returns_entry = {
            "func": "op_2() -> (Tensor, Tensor)",
            "dispatch": {"CPU": "kernel_1"},
            "autogen": "op_2.out",
        }
        self.two_returns_func, two_returns_metadata = NativeFunction.from_yaml(
            two_returns_entry,
            loc=Location(__file__, 1),
            valid_tags=set(),
        )
        BackendIndex.grow_index(self.backend_indices, two_returns_metadata)

    def _check_autogen_out(
        self,
        functional: NativeFunction,
        expected_schema: str,
        expected_kernel: str,
    ) -> None:
        """Run autogen on ``functional`` and check the generated out= overload.

        Args:
            functional: the functional NativeFunction to generate from.
            expected_schema: exact ``str()`` of the generated overload's schema.
            expected_kernel: kernel name expected in the
                CompositeExplicitAutograd backend index for that overload.
        """
        generated = [functional]
        # Appends the generated overload(s) to ``generated`` and registers
        # their kernels in ``self.backend_indices``.
        add_generated_native_functions(generated, self.backend_indices)
        # Exactly one new entry: the out= variant, after the original op.
        self.assertEqual(len(generated), 2)
        out_schema = generated[1].func
        self.assertEqual(str(out_schema), expected_schema)
        # The generated kernel is looked up under CompositeExplicitAutograd
        # even when the functional op itself is dispatched elsewhere (CPU).
        composite_index = self.backend_indices[
            DispatchKey.CompositeExplicitAutograd
        ]
        self.assertEqual(composite_index[out_schema.name].kernel, expected_kernel)

    def test_functional_variant_autogen_out_variant(self) -> None:
        self._check_autogen_out(
            self.one_return_func,
            expected_schema="op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)",
            expected_kernel="op_out",
        )

    def test_functional_variant_autogen_out_variant_two_returns(self) -> None:
        self._check_autogen_out(
            self.two_returns_func,
            expected_schema=(
                "op_2.out(*, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))"
            ),
            expected_kernel="op_2_out",
        )
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker
# Tests for torchgen.gen.static_dispatch.
# NOTE(review): "Generatrion" is a typo; left as-is so existing test selectors
# keep working.
class TestStaticDispatchGeneratrion(unittest.TestCase):
    """Check the C++ call ``static_dispatch`` emits for a single-backend op.

    The operator under test is an ``op.out`` overload with one
    CompositeExplicitAutograd kernel. Both a dispatcher signature and a C++ API
    signature must produce the same direct call into that backend's namespace.
    """

    # Expected output of static_dispatch for ``op.out`` in both tests below.
    EXPECTED_STATIC_DISPATCH = (
        "return at::compositeexplicitautograd::op_out(out, self);"
    )

    def setUp(self) -> None:
        self.backend_indices: dict[
            DispatchKey, dict[OperatorName, BackendMetadata]
        ] = defaultdict(dict)
        yaml_entry = """
- func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  dispatch:
    CompositeExplicitAutograd: op
        """
        es = yaml.load(yaml_entry, Loader=LineLoader)
        self.one_return_func, m = NativeFunction.from_yaml(
            es[0], loc=Location(__file__, 1), valid_tags=set()
        )

        BackendIndex.grow_index(self.backend_indices, m)
        dispatch_key = DispatchKey.CompositeExplicitAutograd
        # Sanity check on the fixture itself. assertIn reports the missing key
        # and the actual container on failure, unlike assertTrue(x in y).
        self.assertIn(dispatch_key, self.backend_indices)
        # static_dispatch takes a list of BackendIndex objects; only the one
        # CompositeExplicitAutograd index is supplied here.
        self.indices = [
            BackendIndex(
                dispatch_key=dispatch_key,
                use_out_as_primary=True,
                external=False,
                device_guard=False,
                index=self.backend_indices[dispatch_key],
            )
        ]

    def test_op_with_1_backend_generates_static_dispatch(self) -> None:
        disp_sig = DispatcherSignature.from_schema(self.one_return_func.func)
        with native_function_manager(self.one_return_func):
            out = static_dispatch(
                sig=disp_sig,
                f=self.one_return_func,
                backend_indices=self.indices,
            )
        self.assertEqual(out, self.EXPECTED_STATIC_DISPATCH)

    def test_op_with_cpp_sig_generates_static_dispatch(self) -> None:
        sig_group = CppSignatureGroup.from_native_function(
            self.one_return_func,
            method=False,
            fallback_binding=self.one_return_func.manual_cpp_binding,
        )
        # cpp signature puts out at the front
        with native_function_manager(self.one_return_func):
            out = static_dispatch(
                sig=sig_group.signature,
                f=self.one_return_func,
                backend_indices=self.indices,
            )
        self.assertEqual(out, self.EXPECTED_STATIC_DISPATCH)
499*da0073e9SAndroid Build Coastguard Worker
500*da0073e9SAndroid Build Coastguard Worker
# Minimal NativeFunction fixture: a zero-argument schema returning bool, with
# no dispatch entries. Tests should derive variants from it with
# dataclasses.replace(). The backend-metadata mapping that from_yaml returns
# as its second element is discarded.
DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml(
    dict(func="func() -> bool"),
    valid_tags=set(),
    loc=Location(__file__, 1),
)
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Executed directly as a script: run every TestCase defined in this module.
    unittest.main(module="__main__")
512