xref: /aosp_15_r20/external/executorch/exir/dialects/edge/test/test_edge_ops.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1#!/usr/bin/env fbpython
2# Copyright (c) Meta Platforms, Inc. and affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import unittest
9
10from typing import List, Optional
11
12import torch
13
14from executorch.exir.dialects._ops import ops
15from executorch.exir.dialects.edge._ops import (
16    _edge_dialect_info,
17    AllowedDtypeSet,
18    EdgeOpOverload,
19    FunctionDtypeConstraint,
20)
21from torch._ops import OpOverload
22from torch.library import impl, Library
23
# Library holding the fake operators registered below under the `test_op` namespace.
lib = Library("test_op", "DEF")

# Fake an operator for testing.
# This operator takes two tensors as input and returns the first one.
lib.define("foo(Tensor self, Tensor other) -> Tensor")
29
30
@impl(lib, "foo", "CPU")
def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """CPU implementation registered for the fake ``foo`` operator.

    Args:
        a: First tensor input; returned as-is.
        b: Second tensor input; ignored.

    Returns:
        The first input ``a``, untouched.
    """
    # Do nothing and return a.
    return a
35
36
def foo_dtype_constraint():
    """Install the dtype-constraint entry for ``test_op::foo`` in ``_edge_dialect_info``.

    The entry declares three dtype aliases (``T0``, ``T1``, ``T2``) and two
    allowed alias combinations for the tensor inputs and the single return
    value. Calling it again overwrites the entry with a freshly built one.
    """
    op_name = "test_op::foo"

    # Alias name -> dtype names that the alias may stand for.
    dtype_aliases = {
        "T0": ["Float", "Double"],
        "T1": ["Char"],
        "T2": ["Int"],
    }

    # Each row assigns one alias to every tensor input/output, in this order.
    tensor_io_names = ("self", "other", "__ret_0")
    alias_rows = (
        ("T0", "T0", "T0"),
        ("T1", "T1", "T2"),
    )
    allowed_combinations = [dict(zip(tensor_io_names, row)) for row in alias_rows]

    # Update the type constraint for function foo.
    _edge_dialect_info[op_name] = dict(
        func="foo",
        namespace="edge",
        inherits=op_name,
        type_alias=dtype_aliases,
        type_constraint=allowed_combinations,
    )
68
69
# Fake an operator that is not included in edge.yaml, for testing.
# This operator takes a tensor, a tensor list and an optional tensor as input
# and returns a tensor list.
lib.define(
    "yaml_unincluded(Tensor self, Tensor[] other_list, Tensor? other_optional) -> Tensor[]"
)
75
76
@impl(lib, "yaml_unincluded", "CPU")
def yaml_unincluded(
    a: torch.Tensor, b: List[torch.Tensor], c: Optional[torch.Tensor]
) -> List[torch.Tensor]:
    """CPU implementation registered for the fake ``yaml_unincluded`` operator.

    Args:
        a: Tensor input; ignored.
        b: Tensor-list input; returned as-is.
        c: Optional tensor input; ignored.

    Returns:
        The list ``b``, untouched.
    """
    # Do nothing and return b.
    return b
83
84
class TestEdgeOps(unittest.TestCase):
    """Tests for edge dialect operator overloads and their dtype constraints."""

    def setUp(self) -> None:
        """Look up the aten/edge ``add`` ops and the fake ``test_op::foo`` edge op."""
        self.aten_add: OpOverload = torch.ops.aten.add.Tensor
        self.edge_add: EdgeOpOverload = ops.edge.aten.add.Tensor

        # Register the dtype constraints of test_op::foo, then look up its edge op.
        foo_dtype_constraint()
        self.edge_foo: EdgeOpOverload = ops.edge.test_op.foo.default

    def _check_validate(self, schema, cases: List[tuple]) -> None:
        """Check ``schema.dtype_constraint.validate`` against a table of cases.

        Args:
            schema: Edge op schema exposing ``dtype_constraint.validate``.
            cases: ``(dtypes, expected)`` pairs, where ``dtypes`` maps tensor
                input/output names to dtypes and ``expected`` tells whether the
                validator should accept that combination.

        Each pair runs in its own ``subTest`` so that, with runners supporting
        sub-tests, one failing combination does not hide the remaining ones.
        """
        for dtypes, expected in cases:
            with self.subTest(dtypes=dtypes, expected=expected):
                check = self.assertTrue if expected else self.assertFalse
                check(schema.dtype_constraint.validate(dtypes))

    def test_callable_gives_same_result(self) -> None:
        """The edge ``add`` op computes the same result as the aten ``add`` op."""
        a = torch.ones(2, 3)
        b = torch.ones(2, 3) * 2
        c = torch.ones(2, 3) * 3
        self.assertTrue(torch.allclose(c, self.edge_add(a, b)))
        self.assertTrue(torch.allclose(self.edge_add(a, b), self.aten_add(a, b)))

    def test_schema_name_same_as_aten_op(self) -> None:
        """The edge op schema keeps the name of the aten op schema."""
        self.assertEqual(self.aten_add._schema.name, self.edge_add._schema.name)

    def test_edge_argument_dtype_constraints(self) -> None:
        """Tensor arguments and returns of ``_log_softmax`` allow only float dtypes."""
        edge_log_softmax: OpOverload = ops.edge.aten._log_softmax.default
        arguments = edge_log_softmax._schema.arguments
        returns = edge_log_softmax._schema.returns
        expected_types = {torch.float16, torch.float32, torch.float64}

        for group in (arguments, returns):
            for io in group:
                # Only tensor-typed inputs/outputs are checked.
                if isinstance(io.type, torch.TensorType):
                    self.assertIsInstance(io.allowed_types, set)
                    self.assertEqual(io.allowed_types, expected_types)

    def test_allowed_dtype_set(self) -> None:
        """``AllowedDtypeSet`` supports membership, ``reduce_to`` and ``clear``."""
        allowed_dtype_set = AllowedDtypeSet({torch.int8, torch.int32})
        self.assertIn(torch.int8, allowed_dtype_set)
        self.assertIn(torch.int32, allowed_dtype_set)

        # torch.int16 is not a legal dtype for allowed_dtype_set
        self.assertFalse(allowed_dtype_set.reduce_to(torch.int16))
        self.assertTrue(allowed_dtype_set.reduce_to(torch.int32))

        # now allowed_dtype_set is reduced to torch.int32
        self.assertNotIn(torch.int8, allowed_dtype_set)
        self.assertIn(torch.int32, allowed_dtype_set)

        # clear it to bring back the original set
        allowed_dtype_set.clear()
        self.assertIn(torch.int8, allowed_dtype_set)
        self.assertIn(torch.int32, allowed_dtype_set)

    def test_edge_add_dtype_constraints_content(self) -> None:
        """The schema of edge ``foo`` carries the registered aliases and constraints."""
        dtype_constraint = self.edge_foo._schema.dtype_constraint
        self.assertIsInstance(dtype_constraint, FunctionDtypeConstraint)

        type_alias = dtype_constraint.type_alias
        self.assertIsInstance(type_alias, dict)
        for value in type_alias.values():
            self.assertIsInstance(value, AllowedDtypeSet)

        # Compare the whole mapping so a missing or extra alias is detected too,
        # instead of only checking the aliases that happen to be present.
        self.assertEqual(
            {key: value.types for key, value in type_alias.items()},
            {
                "T0": {torch.float32, torch.float64},
                "T1": {torch.int8},
                "T2": {torch.int32},
            },
        )

        self.assertEqual(
            dtype_constraint.type_constraint,
            [
                {
                    "self": "T0",
                    "other": "T0",
                    "__ret_0": "T0",
                },
                {
                    "self": "T1",
                    "other": "T1",
                    "__ret_0": "T2",
                },
            ],
        )

    def test_edge_op_dtype_constraints_validation_function(self) -> None:
        """``validate`` accepts only the dtype combinations registered for ``foo``."""
        cases = [
            (
                {
                    "self": torch.float32,
                    "other": torch.float32,
                    "__ret_0": torch.float32,
                },
                True,
            ),
            # Return dtype differs from the inputs.
            (
                {
                    "self": torch.float32,
                    "other": torch.float32,
                    "__ret_0": torch.float64,
                },
                False,
            ),
            # Return dtype is missing.
            (
                {
                    "self": torch.float32,
                    "other": torch.float32,
                },
                False,
            ),
            # Input `self` is missing.
            (
                {
                    "other": torch.float32,
                    "__ret_0": torch.float32,
                },
                False,
            ),
            (
                {"self": torch.int8, "other": torch.int8, "__ret_0": torch.int32},
                True,
            ),
            # `__ret` is not a valid name for the return value.
            (
                {"self": torch.int8, "other": torch.int8, "__ret": torch.int32},
                False,
            ),
            # int8 inputs require an int32 return.
            (
                {"self": torch.int8, "other": torch.int8, "__ret_0": torch.int8},
                False,
            ),
        ]
        self._check_validate(self.edge_foo._schema, cases)

    def test_edge_op_dtype_constraints_validation_function_with_optional_tensor_input(
        self,
    ) -> None:
        """``validate`` treats ``weight`` and ``bias`` of layer norm as optional."""
        edge_native_layer_norm = ops.edge.aten.native_layer_norm.default
        edge_native_layer_norm_schema = edge_native_layer_norm._schema
        # In native layer norm, there are six tensor inputs and outputs, but weight
        # and bias are both optional. Therefore, the dtype validator should return
        # True if the user does not provide the corresponding argument, or provides
        # the optional argument in a correct dtype.
        outputs = {
            "__ret_0": torch.float32,
            "__ret_1": torch.float32,
            "__ret_2": torch.float32,
        }
        cases = [
            ({"input": torch.float32, **outputs}, True),
            ({"input": torch.float32, "weight": torch.float32, **outputs}, True),
            ({"input": torch.float32, "bias": torch.float32, **outputs}, True),
            (
                {
                    "input": torch.float32,
                    "weight": torch.float32,
                    "bias": torch.float32,
                    **outputs,
                },
                True,
            ),
            # Optional argument provided with a wrong dtype.
            (
                {
                    "input": torch.float32,
                    "weight": torch.float32,
                    "bias": torch.int32,
                    **outputs,
                },
                False,
            ),
            # Any other tensor input/output is essential. The dtype validator
            # should return False if the user does not forward all of them.
            ({"weight": torch.float32, "bias": torch.float32, **outputs}, False),
            (
                {
                    "weight": torch.float32,
                    "bias": torch.float32,
                    "__ret_0": torch.float32,
                    "__ret_1": torch.float32,
                },
                False,
            ),
        ]
        self._check_validate(edge_native_layer_norm_schema, cases)

    def test_edge_op_dtype_constraints_validation_function_with_tensor_list_input(
        self,
    ) -> None:
        """``validate`` handles the tensor-list input of ``cat``."""
        edge_cat = ops.edge.aten.cat.default
        edge_cat_schema = edge_cat._schema
        # The input of cat, `tensors`, is a tensor list.
        # Test if edge dialect can validate the correctness of tensor list type.
        cases = [
            ({"tensors": torch.float32, "__ret_0": torch.float32}, True),
            ({"tensors": torch.half, "__ret_0": torch.half}, True),
            ({"tensors": torch.half, "__ret_0": torch.float}, False),
            ({"tensors": torch.half, "non-sense": torch.half}, False),
        ]
        self._check_validate(edge_cat_schema, cases)

    def test_op_not_included_by_yaml(self) -> None:
        """Operators that are not listed in edge.yaml are still validated."""
        # We should support operators not listed in edge.yaml.
        # For such a function, any given dtype combination is legal as long as:
        # a. each dtype is supported by executorch
        # b. all essential tensor-like inputs are provided
        # c. provided inputs other than essential tensor-like inputs are optional
        #    tensor-like inputs.
        edge_op_test = ops.edge.test_op.yaml_unincluded.default
        edge_op_test_schema = edge_op_test._schema
        cases = [
            (
                {
                    "self": torch.float32,
                    "other_list": torch.float32,
                    "other_optional": torch.float32,
                    "__ret_0": torch.float32,
                },
                True,
            ),
            (
                {
                    "self": torch.float32,
                    "other_list": torch.int32,
                    "other_optional": torch.int8,
                    "__ret_0": torch.int32,
                },
                True,
            ),
            (
                {
                    "self": torch.float32,
                    "other_list": torch.int32,
                    "other_optional": torch.int8,
                    "__ret_0": torch.bool,
                },
                True,
            ),
            # The optional input may be left out.
            (
                {
                    "self": torch.float32,
                    "other_list": torch.float32,
                    "__ret_0": torch.float32,
                },
                True,
            ),
            # The essential input `other_list` is missing.
            (
                {
                    "self": torch.float32,
                    "other_optional": torch.float32,
                    "__ret_0": torch.float32,
                },
                False,
            ),
        ]
        self._check_validate(edge_op_test_schema, cases)

    def test_to_out_variant_returns_correct_op(self) -> None:
        """``to_out_variant`` of edge ``add.Tensor`` is aten ``add.out``."""
        out = self.edge_add.to_out_variant()
        self.assertEqual(out, torch.ops.aten.add.out)

    def test_to_out_variant_raises_exception_when_no_out_variant(self) -> None:
        """``to_out_variant`` raises ``RuntimeError`` when no out variant exists."""
        view_op = ops.edge.aten.view.default
        # The expected message is a regular expression: escape the dots so they
        # only match literal dots.
        with self.assertRaisesRegex(
            RuntimeError,
            r"SchemaKind\.out variant of operator aten::view can't be found\.",
        ):
            view_op.to_out_variant()

    def test_get_new_registered_out_var(
        self,
    ) -> None:
        """An out variant defined after the first lookup is found by a later lookup."""
        library = Library("TEST_ONLY", "DEF")
        library.define("foo.Tensor(Tensor a, Tensor b) -> Tensor")
        op = ops.edge.TEST_ONLY.foo.Tensor

        # No out variant has been defined yet.
        with self.assertRaises(RuntimeError):
            op.to_out_variant()

        library.define(
            "foo.Tensor_out(Tensor a, Tensor b, *, Tensor(a!) out) -> Tensor(a!)"
        )
        out = op.to_out_variant()
        self.assertEqual(out, torch.ops.TEST_ONLY.foo.Tensor_out)
442