xref: /aosp_15_r20/external/executorch/exir/dialects/edge/test/test_edge_yaml.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8import unittest
9from typing import Any, Dict, List, Set
10
11import torch
12from executorch.exir.dialects.edge.dtype.supported import regular_tensor_dtypes_to_str
13
14from executorch.exir.dialects.edge.spec.gen import (
15    EdgeOpYamlInfo,
16    gen_op_yaml,
17    get_sample_input,
18)
19
20
class TestEdgeYaml(unittest.TestCase):
    """Tests for edge dialect operator yaml generation.

    Covers type alias / type constraint derivation in ``EdgeOpYamlInfo``,
    yaml generation for individual aten operators via ``gen_op_yaml``, and the
    dtype coverage of ``get_sample_input``. The ``assert*Equal`` helpers compare
    generated yaml objects.
    """

    def assertTypeAliasEqual(
        self, type_alias_1: List[List[str]], type_alias_2: List[List[str]]
    ) -> None:
        """Assert two lists of type aliases are equal, ignoring order.

        Each type alias is a list of dtype names and must not contain duplicate
        names. The two lists are equal when they contain the same aliases the
        same number of times, where each alias is compared as a set of names
        (neither the order of the aliases nor the order of names matters).
        """
        self.assertEqual(len(type_alias_1), len(type_alias_2))
        for type_alias in (*type_alias_1, *type_alias_2):
            self.assertEqual(
                len(type_alias),
                len(set(type_alias)),
                f"type alias {type_alias} contains duplicate dtype names",
            )
        # Sorting the names inside each alias, then the aliases themselves, yields
        # a canonical form: an order-insensitive multiset comparison that also
        # gives a readable diff on failure.
        self.assertEqual(
            sorted(sorted(type_alias) for type_alias in type_alias_1),
            sorted(sorted(type_alias) for type_alias in type_alias_2),
        )

    def assertOpYamlEqual(
        self, op_yaml_1: Dict[str, Any], op_yaml_2: Dict[str, Any]
    ) -> None:
        """Assert two edge operator yaml entries are equal.

        Both entries must have exactly the same keys, and ``func`` and
        ``namespace`` must be present. Under ``type_alias``, alias names must
        match and the dtype names listed for each alias are compared
        order-insensitively via ``assertTypeAliasEqual``. Every other value must
        be exactly equal.
        """
        # Compare the key sets in both directions so that keys present only in
        # op_yaml_2 are reported as well.
        self.assertEqual(set(op_yaml_1), set(op_yaml_2))
        for required_key in ("func", "namespace"):
            self.assertIn(required_key, op_yaml_1)

        for op_yaml_key, value_1 in op_yaml_1.items():
            value_2 = op_yaml_2[op_yaml_key]
            if op_yaml_key == "type_alias":
                # Alias names are referenced by type constraints, so the dtypes
                # must be compared per alias name rather than pooled together.
                self.assertEqual(set(value_1), set(value_2))
                for alias_name, dtypes_1 in value_1.items():
                    self.assertTypeAliasEqual([dtypes_1], [value_2[alias_name]])
            else:
                self.assertEqual(value_1, value_2)

    def assertEdgeYamlEqual(
        self, edge_yaml_1: List[Dict[str, Any]], edge_yaml_2: List[Dict[str, Any]]
    ) -> None:
        """Assert two edge dialect yaml documents are equal.

        Each document is a list of operator entries. Entries are matched by their
        ``func`` value (so list order does not matter) and then compared with
        ``assertOpYamlEqual``.
        """
        self.assertEqual(len(edge_yaml_1), len(edge_yaml_2))
        dict_edge_yaml_1: Dict[str, Dict[str, Any]] = {
            op["func"]: op for op in edge_yaml_1
        }
        dict_edge_yaml_2: Dict[str, Dict[str, Any]] = {
            op["func"]: op for op in edge_yaml_2
        }
        # Entries sharing a ``func`` would silently collapse in the dicts above.
        self.assertEqual(
            len(dict_edge_yaml_1), len(edge_yaml_1), "duplicate func in edge_yaml_1"
        )
        self.assertEqual(
            len(dict_edge_yaml_2), len(edge_yaml_2), "duplicate func in edge_yaml_2"
        )
        self.assertEqual(set(dict_edge_yaml_1), set(dict_edge_yaml_2))

        for func_name, op_yaml_1 in dict_edge_yaml_1.items():
            self.assertOpYamlEqual(op_yaml_1, dict_edge_yaml_2[func_name])

    def test_edge_op_yaml_info_combine_types_with_all_same_types(self) -> None:
        """Check that EdgeOpYamlInfo (a) generates the correct type alias and type
        constraint and (b) properly combines type combinations whose input types
        are all the same (e.g. (FloatTensor, FloatTensor, FloatTensor) and
        (DoubleTensor, DoubleTensor, DoubleTensor)).
        """

        example_yaml_info = EdgeOpYamlInfo(
            func_name="add.Tensor",
            tensor_variable_names=["self", "other", "__ret"],
            inherits="aten::add.Tensor",
            allowed_types={
                ("Float", "Float", "Float"),
                ("Double", "Double", "Double"),
                ("Char", "Char", "Int"),
            },
        )

        self.assertEqual(example_yaml_info.func_name, "add.Tensor")
        self.assertEqual(
            example_yaml_info.tensor_variable_names, ["self", "other", "__ret"]
        )
        self.assertEqual(example_yaml_info.inherits, "aten::add.Tensor")
        self.assertEqual(example_yaml_info.custom, "")
        self.assertEqual(
            example_yaml_info.type_alias,
            [("Char",), ("Double", "Float"), ("Int",)],
        )
        self.assertEqual(example_yaml_info.type_constraint, [(0, 0, 2), (1, 1, 1)])

    def test_edge_op_yaml_info_combine_same_format(self) -> None:
        """Check that EdgeOpYamlInfo (a) generates the correct type alias and type
        constraint and (b) properly combines inputs that share the same format.

        Two inputs share the same format when exactly one of their corresponding
        tensors differs. For example, (DoubleTensor, DoubleTensor, FloatTensor)
        shares the same format with (DoubleTensor, DoubleTensor, DoubleTensor),
        but not with (DoubleTensor, FloatTensor, DoubleTensor).
        """

        example_yaml_info = EdgeOpYamlInfo(
            func_name="tanh",
            tensor_variable_names=["self", "__ret_0"],
            inherits="aten::tanh",
            allowed_types={
                ("Bool", "Float"),
                ("Byte", "Float"),
                ("Char", "Float"),
                ("Short", "Float"),
                ("Int", "Float"),
                ("Long", "Int"),
                ("Float", "Float"),
                ("Double", "Double"),
            },
        )

        self.assertEqual(example_yaml_info.func_name, "tanh")
        self.assertEqual(example_yaml_info.tensor_variable_names, ["self", "__ret_0"])
        self.assertEqual(example_yaml_info.inherits, "aten::tanh")
        self.assertEqual(example_yaml_info.custom, "")
        self.assertEqual(
            example_yaml_info.type_alias,
            [
                ("Bool", "Byte", "Char", "Float", "Int", "Short"),
                ("Double",),
                ("Float",),
                ("Int",),
                ("Long",),
            ],
        )
        self.assertEqual(example_yaml_info.type_constraint, [(0, 2), (1, 1), (4, 3)])

    def test_optional_tensor_supported(self) -> None:
        # Two of the three tensor inputs of native_layer_norm are optional tensors.
        ret = gen_op_yaml("native_layer_norm.default")
        self.assertIsNotNone(ret)
        self.assertEqual(ret.func_name, "aten::native_layer_norm")
        self.assertEqual(ret.inherits, "aten::native_layer_norm")
        self.assertEqual(ret.custom, "")
        self.assertEqual(ret.type_alias, [("Double", "Float", "Half")])
        self.assertEqual(ret.type_constraint, [(0, 0, 0, 0, 0, 0)])
        self.assertEqual(
            ret.tensor_variable_names,
            ["input", "weight", "bias", "__ret_0", "__ret_1", "__ret_2"],
        )

    def test_tensor_list_supported(self) -> None:
        # The input of cat is a tensor list.
        ret = gen_op_yaml("cat.default")
        self.assertIsNotNone(ret)
        self.assertEqual(ret.func_name, "aten::cat")
        self.assertEqual(ret.inherits, "aten::cat")
        self.assertEqual(ret.custom, "")
        self.assertEqual(
            ret.type_alias,
            [
                (
                    "Bool",
                    "Byte",
                    "Char",
                    "Double",
                    "Float",
                    "Half",
                    "Int",
                    "Long",
                    "Short",
                )
            ],
        )
        self.assertEqual(ret.type_constraint, [(0, 0)])
        self.assertEqual(ret.tensor_variable_names, ["tensors", "__ret_0"])

    # Check whether any function was updated by comparing the current yaml file
    # with the previous one. If anything mismatches, follow the instructions at
    # the top of //executorch/exir/dialects/edge/edge.yaml.
    # TODO(gasoonjia, T159593834): Should be updated after support other models and infer methods.
    # def test_need_update_edge_yaml(self) -> None:
    #     model = <need OSS model example>
    #     model_edge_dialect_operators: List[str] = get_all_ops(model)
    #     with tempfile.NamedTemporaryFile(mode="w+") as yaml_stream:
    #         _ = gen_edge_yaml(model_edge_dialect_operators, yaml_stream)
    #         yaml_stream.seek(0, 0)
    #         self.assertTrue(
    #             filecmp.cmp(
    #                 yaml_stream.name,
    #                 "executorch/exir/dialects/edge/edge.yaml",
    #             ),
    #             "Please run `//executorch/exir/dialects/edge:yaml_generator -- --regenerate` to regenerate the file.",
    #         )

    def test_to_copy_sample_input_has_enough_coverage(self) -> None:
        """Make sure sample input to _to_copy(Tensor self, *, ScalarType dtype, ...) has enough coverage"""
        sample_input = get_sample_input(
            key="to", overload_name="", edge_type=torch.float32
        )
        dtype_set: Set[torch.dtype] = set()
        for _, kwargs in sample_input:
            self.assertIn("dtype", kwargs)
            dtype_set.add(kwargs["dtype"])

        # Compare as plain sets so a failure reports exactly which dtypes differ.
        self.assertEqual(dtype_set, set(regular_tensor_dtypes_to_str.keys()))
225