#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Unit tests for EdgeOpOverload and its dtype-constraint machinery.

Registers two fake operators in a throwaway ``test_op`` library:

* ``test_op::foo`` — gets an explicit dtype constraint injected into
  ``_edge_dialect_info`` (see ``foo_dtype_constraint``).
* ``test_op::yaml_unincluded`` — deliberately left OUT of edge.yaml to
  exercise the fallback validation path for unlisted operators.
"""

import unittest

from typing import List, Optional

import torch

from executorch.exir.dialects._ops import ops
from executorch.exir.dialects.edge._ops import (
    _edge_dialect_info,
    AllowedDtypeSet,
    EdgeOpOverload,
    FunctionDtypeConstraint,
)
from torch._ops import OpOverload
from torch.library import impl, Library

lib = Library("test_op", "DEF")

# Fake an operator for testing.
# This operator takes two tensors as input and returns the first one.
lib.define("foo(Tensor self, Tensor other) -> Tensor")


@impl(lib, "foo", "CPU")
def foo(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """CPU kernel for test_op::foo: ignore ``b`` and return ``a`` unchanged."""
    return a


def foo_dtype_constraint() -> None:
    """Inject the dtype constraint for ``test_op::foo`` into the edge dialect info.

    NOTE(review): this mutates the module-global ``_edge_dialect_info`` and is
    never undone, so the entry persists for the rest of the test process.
    The constraint says: either all tensors are Float/Double (and match), or
    both inputs are Char and the return is Int.
    """
    _edge_dialect_info["test_op::foo"] = {
        "func": "foo",
        "namespace": "edge",
        "inherits": "test_op::foo",
        "type_alias": {
            "T0": [
                "Float",
                "Double",
            ],
            "T1": [
                "Char",
            ],
            "T2": [
                "Int",
            ],
        },
        "type_constraint": [
            {
                "self": "T0",
                "other": "T0",
                "__ret_0": "T0",
            },
            {
                "self": "T1",
                "other": "T1",
                "__ret_0": "T2",
            },
        ],
    }


# Fake an operator that is NOT included in edge.yaml, for testing the
# fallback dtype-validation path.
# It takes a tensor, a tensor list, and an optional tensor, and returns
# the tensor list unchanged.
lib.define(
    "yaml_unincluded(Tensor self, Tensor[] other_list, Tensor? other_optional) -> Tensor[]"
)


@impl(lib, "yaml_unincluded", "CPU")
def yaml_unincluded(
    a: torch.Tensor, b: List[torch.Tensor], c: Optional[torch.Tensor]
) -> List[torch.Tensor]:
    """CPU kernel for test_op::yaml_unincluded: return the tensor list ``b``."""
    return b


class TestEdgeOps(unittest.TestCase):
    """Tests for EdgeOpOverload schemas and FunctionDtypeConstraint validation."""

    def setUp(self) -> None:
        self.aten_add: OpOverload = torch.ops.aten.add.Tensor
        self.edge_add: EdgeOpOverload = ops.edge.aten.add.Tensor

        # Register the dtype constraint BEFORE resolving the edge op so the
        # schema picks it up.
        foo_dtype_constraint()
        self.edge_foo: EdgeOpOverload = ops.edge.test_op.foo.default

    def test_callable_gives_same_result(self) -> None:
        """Calling the edge op must dispatch to the same kernel as the aten op."""
        a = torch.ones(2, 3)
        b = torch.ones(2, 3) * 2
        c = torch.ones(2, 3) * 3
        self.assertTrue(torch.allclose(c, self.edge_add(a, b)))
        self.assertTrue(torch.allclose(self.edge_add(a, b), self.aten_add(a, b)))

    def test_schema_name_same_as_aten_op(self) -> None:
        self.assertEqual(self.aten_add._schema.name, self.edge_add._schema.name)

    def test_edge_argument_dtype_constraints(self) -> None:
        """Tensor args/returns of edge _log_softmax only allow floating dtypes."""
        # Edge ops carry per-argument allowed dtype sets; annotate accordingly
        # (consistent with the EdgeOpOverload annotations in setUp).
        edge_log_softmax: EdgeOpOverload = ops.edge.aten._log_softmax.default
        arguments = edge_log_softmax._schema.arguments
        returns = edge_log_softmax._schema.returns
        for arg in arguments:
            if isinstance(arg.type, torch.TensorType):
                self.assertTrue(isinstance(arg.allowed_types, set))
                self.assertEqual(
                    arg.allowed_types, {torch.float16, torch.float32, torch.float64}
                )

        for ret in returns:
            if isinstance(ret.type, torch.TensorType):
                self.assertTrue(isinstance(ret.allowed_types, set))
                self.assertEqual(
                    ret.allowed_types, {torch.float16, torch.float32, torch.float64}
                )

    def test_allowed_dtype_set(self) -> None:
        """AllowedDtypeSet membership, reduce_to(), and clear() round-trip."""
        allowed_dtype_set = AllowedDtypeSet({torch.int8, torch.int32})
        self.assertTrue(torch.int8 in allowed_dtype_set)
        self.assertTrue(torch.int32 in allowed_dtype_set)

        # torch.int16 is not a legal dtype for allowed_dtype_set
        self.assertFalse(allowed_dtype_set.reduce_to(torch.int16))
        self.assertTrue(allowed_dtype_set.reduce_to(torch.int32))

        # now allowed_dtype_set is reduced to torch.int32
        self.assertFalse(torch.int8 in allowed_dtype_set)
        self.assertTrue(torch.int32 in allowed_dtype_set)

        # clear() restores the full dtype set
        allowed_dtype_set.clear()
        self.assertTrue(torch.int8 in allowed_dtype_set)
        self.assertTrue(torch.int32 in allowed_dtype_set)

    def test_edge_add_dtype_constraints_content(self) -> None:
        """The injected test_op::foo constraint is parsed into the schema verbatim."""
        edge_foo_schema = self.edge_foo._schema
        self.assertTrue(
            isinstance(edge_foo_schema.dtype_constraint, FunctionDtypeConstraint)
        )
        self.assertTrue(isinstance(edge_foo_schema.dtype_constraint.type_alias, dict))
        for key, value in edge_foo_schema.dtype_constraint.type_alias.items():
            self.assertTrue(key in ["T0", "T1", "T2"])
            self.assertTrue(isinstance(value, AllowedDtypeSet))
            if key == "T0":
                self.assertEqual(value.types, {torch.float32, torch.float64})
            elif key == "T1":
                self.assertEqual(
                    value.types,
                    {
                        torch.int8,
                    },
                )
            elif key == "T2":
                self.assertEqual(
                    value.types,
                    {
                        torch.int32,
                    },
                )

        self.assertEqual(
            edge_foo_schema.dtype_constraint.type_constraint,
            [
                {
                    "self": "T0",
                    "other": "T0",
                    "__ret_0": "T0",
                },
                {
                    "self": "T1",
                    "other": "T1",
                    "__ret_0": "T2",
                },
            ],
        )

    def test_edge_op_dtype_constraints_validation_function(self) -> None:
        """validate() accepts only complete, constraint-satisfying dtype maps."""
        edge_foo_schema = self.edge_foo._schema
        # All-float32 satisfies the T0 row.
        self.assertTrue(
            edge_foo_schema.dtype_constraint.validate(
                {
                    "self": torch.float32,
                    "other": torch.float32,
                    "__ret_0": torch.float32,
                }
            )
        )
        # Mixed float32/float64 violates "all three share T0".
        self.assertFalse(
            edge_foo_schema.dtype_constraint.validate(
                {
                    "self": torch.float32,
                    "other": torch.float32,
                    "__ret_0": torch.float64,
                }
            )
        )
        # Missing return entry -> invalid.
        self.assertFalse(
            edge_foo_schema.dtype_constraint.validate(
                {
                    "self": torch.float32,
                    "other": torch.float32,
                }
            )
        )
        # Missing essential input entry -> invalid.
        self.assertFalse(
            edge_foo_schema.dtype_constraint.validate(
                {
                    "other": torch.float32,
                    "__ret_0": torch.float32,
                }
            )
        )
        # Char inputs with Int return satisfy the T1/T2 row.
        self.assertTrue(
            edge_foo_schema.dtype_constraint.validate(
                {"self": torch.int8, "other": torch.int8, "__ret_0": torch.int32}
            )
        )

        # Misspelled return key ("__ret" instead of "__ret_0") -> invalid.
        self.assertFalse(
            edge_foo_schema.dtype_constraint.validate(
                {"self": torch.int8, "other": torch.int8, "__ret": torch.int32}
            )
        )

        # Char inputs must return Int, not Char.
        self.assertFalse(
            edge_foo_schema.dtype_constraint.validate(
                {"self": torch.int8, "other": torch.int8, "__ret_0": torch.int8}
            )
        )

    def test_edge_op_dtype_constraints_validation_function_with_optional_tensor_input(
        self,
    ) -> None:
        edge_native_layer_norm = ops.edge.aten.native_layer_norm.default
        edge_native_layer_norm_schema = edge_native_layer_norm._schema
        # In native layer norm, there are six tensor inputs and outputs, but
        # weight and bias are both optional. Therefore, the dtype validator
        # should return True if the user does not provide the corresponding
        # argument, or provides the optional argument with a correct dtype.

        self.assertTrue(
            edge_native_layer_norm_schema.dtype_constraint.validate(
                {
                    "input": torch.float32,
                    "__ret_0": torch.float32,
                    "__ret_1": torch.float32,
                    "__ret_2": torch.float32,
                }
            )
        )

        self.assertTrue(
            edge_native_layer_norm_schema.dtype_constraint.validate(
                {
                    "input": torch.float32,
                    "weight": torch.float32,
                    "__ret_0": torch.float32,
                    "__ret_1": torch.float32,
                    "__ret_2": torch.float32,
                }
            )
        )

        self.assertTrue(
            edge_native_layer_norm_schema.dtype_constraint.validate(
                {
                    "input": torch.float32,
                    "bias": torch.float32,
                    "__ret_0": torch.float32,
                    "__ret_1": torch.float32,
                    "__ret_2": torch.float32,
                }
            )
        )

        self.assertTrue(
            edge_native_layer_norm_schema.dtype_constraint.validate(
                {
                    "input": torch.float32,
                    "weight": torch.float32,
                    "bias": torch.float32,
                    "__ret_0": torch.float32,
                    "__ret_1": torch.float32,
                    "__ret_2": torch.float32,
                }
            )
        )

        # Optional argument provided with a WRONG dtype -> invalid.
        self.assertFalse(
            edge_native_layer_norm_schema.dtype_constraint.validate(
                {
                    "input": torch.float32,
                    "weight": torch.float32,
                    "bias": torch.int32,
                    "__ret_0": torch.float32,
                    "__ret_1": torch.float32,
                    "__ret_2": torch.float32,
                }
            )
        )

        # Any other tensor input/output should be an essential input/output.
        # The dtype validator should return False if the user does not forward
        # all essential inputs.
        self.assertFalse(
            edge_native_layer_norm_schema.dtype_constraint.validate(
                {
                    "weight": torch.float32,
                    "bias": torch.float32,
                    "__ret_0": torch.float32,
                    "__ret_1": torch.float32,
                    "__ret_2": torch.float32,
                }
            )
        )

        self.assertFalse(
            edge_native_layer_norm_schema.dtype_constraint.validate(
                {
                    "weight": torch.float32,
                    "bias": torch.float32,
                    "__ret_0": torch.float32,
                    "__ret_1": torch.float32,
                }
            )
        )

    def test_edge_op_dtype_constraints_validation_function_with_tensor_list_input(
        self,
    ) -> None:
        edge_cat = ops.edge.aten.cat.default
        edge_cat_schema = edge_cat._schema
        # The input of cat, `tensors`, is a tensor list.
        # Test if edge dialect can validate the correctness of tensor list type.

        self.assertTrue(
            edge_cat_schema.dtype_constraint.validate(
                {"tensors": torch.float32, "__ret_0": torch.float32}
            )
        )
        self.assertTrue(
            edge_cat_schema.dtype_constraint.validate(
                {"tensors": torch.half, "__ret_0": torch.half}
            )
        )
        self.assertFalse(
            edge_cat_schema.dtype_constraint.validate(
                {"tensors": torch.half, "__ret_0": torch.float}
            )
        )
        # Unknown argument name -> invalid.
        self.assertFalse(
            edge_cat_schema.dtype_constraint.validate(
                {"tensors": torch.half, "non-sense": torch.half}
            )
        )

    def test_op_not_included_by_yaml(self) -> None:
        # We should support operators not listed in edge.yaml.
        # For such a function, any given dtype combination is legal as long as:
        # a. each dtype is supported by executorch
        # b. all essential tensor-like inputs are provided
        # c. provided inputs other than essential tensor-like inputs are
        #    optional tensor-like inputs.
        edge_op_test = ops.edge.test_op.yaml_unincluded.default
        edge_op_test_schema = edge_op_test._schema
        self.assertTrue(
            edge_op_test_schema.dtype_constraint.validate(
                {
                    "self": torch.float32,
                    "other_list": torch.float32,
                    "other_optional": torch.float32,
                    "__ret_0": torch.float32,
                }
            )
        )
        self.assertTrue(
            edge_op_test_schema.dtype_constraint.validate(
                {
                    "self": torch.float32,
                    "other_list": torch.int32,
                    "other_optional": torch.int8,
                    "__ret_0": torch.int32,
                }
            )
        )
        self.assertTrue(
            edge_op_test_schema.dtype_constraint.validate(
                {
                    "self": torch.float32,
                    "other_list": torch.int32,
                    "other_optional": torch.int8,
                    "__ret_0": torch.bool,
                }
            )
        )
        # Omitting the optional input is fine.
        self.assertTrue(
            edge_op_test_schema.dtype_constraint.validate(
                {
                    "self": torch.float32,
                    "other_list": torch.float32,
                    "__ret_0": torch.float32,
                }
            )
        )
        # Omitting an essential input (other_list) is not.
        self.assertFalse(
            edge_op_test_schema.dtype_constraint.validate(
                {
                    "self": torch.float32,
                    "other_optional": torch.float32,
                    "__ret_0": torch.float32,
                }
            )
        )

    def test_to_out_variant_returns_correct_op(self) -> None:
        out = self.edge_add.to_out_variant()
        self.assertEqual(out, torch.ops.aten.add.out)

    def test_to_out_variant_raises_exception_when_no_out_variant(self) -> None:
        view_op = ops.edge.aten.view.default
        with self.assertRaisesRegex(
            RuntimeError,
            "SchemaKind.out variant of operator aten::view can't be found.",
        ):
            view_op.to_out_variant()

    def test_get_new_registered_out_var(
        self,
    ) -> None:
        """to_out_variant picks up an out-variant registered AFTER the first lookup."""
        library = Library("TEST_ONLY", "DEF")
        library.define("foo.Tensor(Tensor a, Tensor b) -> Tensor")
        op = ops.edge.TEST_ONLY.foo.Tensor

        # No out variant registered yet -> must raise.
        self.assertRaises(RuntimeError, op.to_out_variant)
        library.define(
            "foo.Tensor_out(Tensor a, Tensor b, *, Tensor(a!) out) -> Tensor(a!)"
        )
        out = op.to_out_variant()
        self.assertEqual(out, torch.ops.TEST_ONLY.foo.Tensor_out)