# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import unittest
from typing import Any, Dict, List, Set

import torch
from executorch.exir.dialects.edge.dtype.supported import regular_tensor_dtypes_to_str

from executorch.exir.dialects.edge.spec.gen import (
    EdgeOpYamlInfo,
    gen_op_yaml,
    get_sample_input,
)


class TestEdgeYaml(unittest.TestCase):
    """Tests for the edge dialect yaml generation helpers."""

    def assertTypeAliasEqual(
        self, type_alias_1: List[List[str]], type_alias_2: List[List[str]]
    ) -> None:
        """Assert that two type alias lists are equal, ignoring ordering.

        Ordering is ignored both between aliases and between the dtype names
        inside a single alias. Every alias must be free of duplicated names.

        Args:
            type_alias_1: First list of aliases; each alias is a list of names.
            type_alias_2: Second list of aliases to compare against.
        """
        self.assertEqual(len(type_alias_1), len(type_alias_2))
        type_alias_set_1: List[Set[str]] = []
        type_alias_set_2: List[Set[str]] = []
        for ta1, ta2 in zip(type_alias_1, type_alias_2):
            # An alias must not list the same name twice.
            self.assertEqual(len(ta1), len(set(ta1)))
            self.assertEqual(len(ta2), len(set(ta2)))
            type_alias_set_1.append(set(ta1))
            type_alias_set_2.append(set(ta2))

        # Each alias on one side must appear, as a set, on the other side.
        for tas1, tas2 in zip(type_alias_set_1, type_alias_set_2):
            self.assertIn(tas1, type_alias_set_2)
            self.assertIn(tas2, type_alias_set_1)

    def assertOpYamlEqual(
        self, op_yaml_1: Dict[str, Any], op_yaml_2: Dict[str, Any]
    ) -> None:
        """Assert that two edge operator yaml objects are equal.

        The "type_alias" entry is compared order-insensitively through
        assertTypeAliasEqual; every other entry is compared with assertEqual.

        Args:
            op_yaml_1: First operator yaml object.
            op_yaml_2: Second operator yaml object.
        """
        # Compare the key sets in both directions so that an entry present in
        # only one of the two objects is reported as a mismatch.
        self.assertEqual(set(op_yaml_1), set(op_yaml_2))

        for op_yaml_key, value_1 in op_yaml_1.items():
            value_2 = op_yaml_2[op_yaml_key]
            if op_yaml_key == "type_alias":
                self.assertEqual(len(value_1), len(value_2))
                type_alias_list_1: List[List[str]] = []
                type_alias_list_2: List[List[str]] = []
                for type_alias_key in value_1:
                    self.assertIn(type_alias_key, value_2)
                    type_alias_list_1.append(value_1[type_alias_key])
                    type_alias_list_2.append(value_2[type_alias_key])

                self.assertTypeAliasEqual(type_alias_list_1, type_alias_list_2)
            else:
                self.assertEqual(value_1, value_2)

        # "func" and "namespace" are looked up unconditionally, so an object
        # lacking either of them fails here with a KeyError.
        self.assertEqual(op_yaml_1["func"], op_yaml_2["func"])
        self.assertEqual(op_yaml_1["namespace"], op_yaml_2["namespace"])

    def assertEdgeYamlEqual(
        self, edge_yaml_1: List[Dict[str, Any]], edge_yaml_2: List[Dict[str, Any]]
    ) -> None:
        """Assert that two edge dialect yaml objects are equal.

        Operators are matched by their "func" entry, so the order of the
        operators inside each list does not matter.

        Args:
            edge_yaml_1: First list of operator yaml objects.
            edge_yaml_2: Second list of operator yaml objects.
        """
        self.assertEqual(len(edge_yaml_1), len(edge_yaml_2))
        dict_edge_yaml_1: Dict[str, Dict[str, Any]] = {
            op["func"]: op for op in edge_yaml_1
        }
        dict_edge_yaml_2: Dict[str, Dict[str, Any]] = {
            op["func"]: op for op in edge_yaml_2
        }

        for op_yaml_key, op_yaml_1 in dict_edge_yaml_1.items():
            # Use a unittest assertion rather than a bare `assert`, which is
            # stripped when Python runs with -O.
            self.assertIn(op_yaml_key, dict_edge_yaml_2)
            self.assertOpYamlEqual(op_yaml_1, dict_edge_yaml_2[op_yaml_key])

    def test_edge_op_yaml_info_combine_types_with_all_same_types(self) -> None:
        """This test aims to check if EdgeOpYamlInfo can a. generate correct type
        alias and type constraint and b. properly combine the type combinations with
        all same input types (e.g. (FloatTensor, FloatTensor, FloatTensor),
        (DoubleTensor, DoubleTensor, DoubleTensor)).
        """

        example_yaml_info = EdgeOpYamlInfo(
            func_name="add.Tensor",
            tensor_variable_names=["self", "other", "__ret"],
            inherits="aten::add.Tensor",
            allowed_types={
                ("Float", "Float", "Float"),
                ("Double", "Double", "Double"),
                ("Char", "Char", "Int"),
            },
        )

        self.assertEqual(example_yaml_info.func_name, "add.Tensor")
        self.assertEqual(
            example_yaml_info.tensor_variable_names, ["self", "other", "__ret"]
        )
        self.assertEqual(example_yaml_info.inherits, "aten::add.Tensor")
        self.assertEqual(example_yaml_info.custom, "")
        self.assertEqual(
            example_yaml_info.type_alias,
            [("Char",), ("Double", "Float"), ("Int",)],
        )
        self.assertEqual(example_yaml_info.type_constraint, [(0, 0, 2), (1, 1, 1)])

    def test_edge_op_yaml_info_combine_same_format(self) -> None:
        """This test aims to check if EdgeOpYamlInfo can a. generate correct type
        alias and type constraint and b. properly combine the inputs with same format.
        Two inputs having same format here means one and only one of their corresponding
        input tensors is different. e.g. {DoubleTensor, DoubleTensor, FloatTensor}
        shares same format with {DoubleTensor, DoubleTensor, DoubleTensor},
        but not {DoubleTensor, FloatTensor, DoubleTensor}.

        """

        example_yaml_info = EdgeOpYamlInfo(
            func_name="tanh",
            tensor_variable_names=["self", "__ret_0"],
            inherits="aten::tanh",
            allowed_types={
                ("Bool", "Float"),
                ("Byte", "Float"),
                ("Char", "Float"),
                ("Short", "Float"),
                ("Int", "Float"),
                ("Long", "Int"),
                ("Float", "Float"),
                ("Double", "Double"),
            },
        )

        self.assertEqual(example_yaml_info.func_name, "tanh")
        self.assertEqual(example_yaml_info.tensor_variable_names, ["self", "__ret_0"])
        self.assertEqual(example_yaml_info.inherits, "aten::tanh")
        self.assertEqual(example_yaml_info.custom, "")
        self.assertEqual(
            example_yaml_info.type_alias,
            [
                ("Bool", "Byte", "Char", "Float", "Int", "Short"),
                ("Double",),
                ("Float",),
                ("Int",),
                ("Long",),
            ],
        )
        self.assertEqual(example_yaml_info.type_constraint, [(0, 2), (1, 1), (4, 3)])

    def test_optional_tensor_supported(self) -> None:
        """Check the yaml info generated for an operator with optional tensor inputs."""
        # Two of three tensor inputs of native_layer_norm are in optional tensor type.
        ret = gen_op_yaml("native_layer_norm.default")
        self.assertIsNotNone(ret)
        self.assertEqual(ret.func_name, "aten::native_layer_norm")
        self.assertEqual(ret.inherits, "aten::native_layer_norm")
        self.assertEqual(ret.custom, "")
        self.assertEqual(ret.type_alias, [("Double", "Float", "Half")])
        self.assertEqual(ret.type_constraint, [(0, 0, 0, 0, 0, 0)])
        self.assertEqual(
            ret.tensor_variable_names,
            ["input", "weight", "bias", "__ret_0", "__ret_1", "__ret_2"],
        )

    def test_tensor_list_supported(self) -> None:
        """Check the yaml info generated for an operator taking a tensor list."""
        # Input of cat is tensor list.
        ret = gen_op_yaml("cat.default")
        self.assertIsNotNone(ret)
        self.assertEqual(ret.func_name, "aten::cat")
        self.assertEqual(ret.inherits, "aten::cat")
        self.assertEqual(ret.custom, "")
        self.assertEqual(
            ret.type_alias,
            [
                (
                    "Bool",
                    "Byte",
                    "Char",
                    "Double",
                    "Float",
                    "Half",
                    "Int",
                    "Long",
                    "Short",
                )
            ],
        )
        self.assertEqual(ret.type_constraint, [(0, 0)])
        self.assertEqual(ret.tensor_variable_names, ["tensors", "__ret_0"])

    # Check whether any function has been updated by comparing the current yaml
    # file with the previous one. If anything mismatches, please follow the
    # instructions at the top of //executorch/exir/dialects/edge/edge.yaml.
    # TODO(gasoonjia, T159593834): Should be updated after support other models and infer methods.
    # def test_need_update_edge_yaml(self) -> None:
    #     model = <need OSS model example>
    #     model_edge_dialect_operators: List[str] = get_all_ops(model)
    #     with tempfile.NamedTemporaryFile(mode="w+") as yaml_stream:
    #         _ = gen_edge_yaml(model_edge_dialect_operators, yaml_stream)
    #         yaml_stream.seek(0, 0)
    #         self.assertTrue(
    #             filecmp.cmp(
    #                 yaml_stream.name,
    #                 "executorch/exir/dialects/edge/edge.yaml",
    #             ),
    #             "Please run `//executorch/exir/dialects/edge:yaml_generator -- --regenerate` to regenerate the file.",
    #         )

    def test_to_copy_sample_input_has_enough_coverage(self) -> None:
        """Make sure sample input to _to_copy(Tensor self, *, ScalarType dtype, ...) has enough coverage"""
        sample_input = get_sample_input(
            key="to", overload_name="", edge_type=torch.float32
        )
        dtype_set: Set[torch.dtype] = set()
        for _, kwargs in sample_input:
            self.assertIn("dtype", kwargs)
            dtype_set.add(kwargs["dtype"])

        # assertEqual reports both sets on failure, unlike assertTrue(a == b).
        self.assertEqual(dtype_set, set(regular_tensor_dtypes_to_str.keys()))