# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import unittest
from typing import Dict, List, Tuple

import executorch.exir as exir
import executorch.exir.tests.models as models

import torch

from executorch.exir import CaptureConfig
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.tests.common import register_additional_test_aten_ops
from executorch.exir.tracer import dynamo_trace, ExirDynamoConfig, using_dynamo
from functorch.experimental.control_flow import cond, map

from parameterized import parameterized
from torch._export.verifier import SpecViolationError
from torch.fx.experimental.symbolic_shapes import is_concrete_int
from torch.testing import FileCheck


class TestTorchDispatchFXTracer(unittest.TestCase):
    """Tests for capturing programs with ``exir.capture`` / ``torch.export``
    and lowering them to the edge dialect."""

    @classmethod
    def setUpClass(cls) -> None:
        # Register the additional test aten ops once for the whole class.
        register_additional_test_aten_ops()

    def test_simple(self) -> None:
        """Lowering ``BasicSinMax`` to edge yields an edge-dialect ``sin`` op."""
        f = models.BasicSinMax()
        f = (
            exir.capture(f, f.get_random_inputs(), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )

        FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").run(f.code)

    def test_static_control_flow(self) -> None:
        """A Python branch on a bool argument is specialized at capture time:
        ``max`` appears only in the graph captured with ``pred=True``."""

        def f(pred: bool, x: torch.Tensor) -> torch.Tensor:
            if pred:
                return torch.sin(x).max()
            else:
                return torch.sin(x)

        pred = True
        x = torch.randn(100)
        f_true = (
            exir.capture(f, (pred, x), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )

        FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").check(
            "executorch_exir_dialects_edge__ops_aten_max"
        ).run(f_true.code)

        pred = False
        f_false = (
            exir.capture(f, (pred, x), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").check_not(
            "executorch_exir_dialects_edge__ops_aten_max"
        ).run(f_false.code)

    def test_copy(self) -> None:
        """An edge graph module is still a ``GraphModule`` after ``copy.deepcopy``."""
        f = models.BasicSinMax()
        f = (
            exir.capture(f, f.get_random_inputs(), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )

        self.assertIsInstance(f, torch.fx.GraphModule)
        g = copy.deepcopy(f)
        self.assertIsInstance(g, torch.fx.GraphModule)

    def test_stacktrace(self) -> None:
        """Nodes keep ``stack_trace`` metadata after lowering to edge."""

        def f(x: torch.Tensor) -> torch.Tensor:
            return x + x

        traced_f = (
            exir.capture(f, (torch.rand(2, 2),), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )
        # Check that stacktrace is populated and retained (by checking twice)
        self.assertTrue(
            any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes)
        )
        self.assertTrue(
            any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes)
        )

    def test_ones(self) -> None:
        """``torch.ones`` sized by a dynamic input dim lowers via ``exir.to_edge``
        without raising."""

        class M(torch.nn.Module):
            def forward(self, x):
                y = torch.ones(x.shape[0])
                return x + y

        ep = torch.export.export(
            M(), (torch.ones(3),), dynamic_shapes={"x": {0: torch.export.Dim("x")}}
        )
        exir.to_edge(ep)

    def test_possible_input_mutation(self) -> None:
        """An ``out=`` op writing into a graph input is rejected by the edge
        verifier as non-functional."""

        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.add(torch.ones(5), torch.ones(5), out=x)

        with self.assertRaisesRegex(
            SpecViolationError,
            r"operator .* is not functional",
        ):
            exir.capture(f, (torch.zeros(5),), exir.CaptureConfig()).to_edge()

    def test_tensor_spec_for_const_tensors(self) -> None:
        """Every ``get_attr`` node (two are expected for a single ``nn.Linear``)
        carries a ``val`` entry in its meta."""

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super(Module, self).__init__()
                self.linear = torch.nn.Linear(2, 3)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.linear(x)

            def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
                return (torch.randn(2),)

        model = Module()
        graph_module = (
            exir.capture(model, model.get_random_inputs(), exir.CaptureConfig())
            # torch._ops.aten.t.default
            .to_edge(
                exir.EdgeCompileConfig(_check_ir_validity=False)
            ).exported_program.graph_module
        )
        num_get_attr_node = 0
        num_get_attr_node_with_tensorspec = 0
        for nd in graph_module.graph.nodes:
            if nd.op == "get_attr":
                num_get_attr_node += 1
                if nd.meta.get("val") is not None:
                    num_get_attr_node_with_tensorspec += 1

        self.assertEqual(2, num_get_attr_node)
        self.assertEqual(2, num_get_attr_node_with_tensorspec)

    def test_multiple_returns_spec(self) -> None:
        """The multi-output ``aten.max.dim`` node gets a tuple ``val`` in its meta."""

        def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return torch.ops.aten.max.dim(x, 0, False)

        cnt = 0
        module = (
            exir.capture(f, (torch.zeros(1, 2, 3),), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )
        for node in module.graph.nodes:
            if node.target == exir_ops.edge.aten.max.dim:
                cnt += 1
                self.assertIsInstance(node.meta["val"], tuple)
        self.assertEqual(cnt, 1)

    def test_multiple_returns_pt2_mode(self) -> None:
        """Returning two tensors yields a two-element list ``val`` on the output
        node, and the lowered module matches eager results."""

        def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            a = x * x
            b = x + a
            return a, b

        inputs = (torch.ones(1, 2, 3),)
        orig_res = f(*inputs)
        module = (
            exir.capture(
                f,
                inputs,
                exir.CaptureConfig(),
            )
            .to_edge()
            .exported_program.graph_module
        )
        new_res = module(*inputs)
        for node in module.graph.nodes:
            if node.op == "output":
                self.assertIsInstance(node.meta["val"], list)
                self.assertEqual(len(node.meta["val"]), 2)

        self.assertTrue(torch.allclose(orig_res[0], new_res[0]))
        self.assertTrue(torch.allclose(orig_res[1], new_res[1]))

    def test_dynamo_capture_scalar_outputs(self) -> None:
        """``dynamo_trace`` accepts a function returning a Python scalar via
        ``Tensor.item()`` (passes if no exception is raised)."""

        def f(x: torch.Tensor) -> float:
            return x.item()

        gm, guards = dynamo_trace(
            f,
            (torch.ones(1),),
            False,
            "real",
            ExirDynamoConfig(),
        )

    # pyre-ignore
    @parameterized.expand([("stock_tensor",)])
    def test_embedding_dynamic_shape(self, input_type: str) -> None:
        """Capture with dynamic shapes and functionalization disabled, then
        lower to edge.

        NOTE(review): ``input_type`` is not used by the body.
        """

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x + x

        example_input = torch.ones(10, dtype=torch.int64)
        m = Module()
        gm = (
            exir.capture(
                m,
                (example_input,),
                exir.CaptureConfig(
                    enable_functionalization=False,
                    enable_dynamic_shape=True,
                ),
            )
            .to_edge()
            .exported_program.graph_module
        )

        # Replaces a leftover debug print of the graph with a real check.
        self.assertIsInstance(gm, torch.fx.GraphModule)

    def test_dynamic_shape(self) -> None:
        """With dynamic shapes enabled, placeholder and call_function nodes
        carry ``val`` meta for a shape-dependent ``view``."""

        def forward(x: torch.Tensor) -> torch.Tensor:
            x = x.view(x.shape[0] - 1, -1)
            return torch.cat([x, x])

        gm = (
            exir.capture(
                forward,
                (torch.ones(3, 2, dtype=torch.int64),),
                exir.CaptureConfig(
                    enable_functionalization=False,
                    enable_dynamic_shape=True,
                    _dynamo_config=ExirDynamoConfig(assume_static_by_default=True),
                ),
                # sym_size is not reg op
            )
            .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
            .exported_program.graph_module
        )

        for node in gm.graph.nodes:
            if node.op in ("placeholder", "call_function"):
                self.assertIn("val", node.meta)

    def test_dynamo_frontend_container_input(self) -> None:
        """Nested tuple inputs can be captured with the dynamo frontend enabled,
        and the result matches eager."""

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super(Module, self).__init__()

            def forward(
                self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
            ) -> torch.Tensor:
                a = x[0]
                b = x[1]
                cum = 0
                for i in b:
                    cum += i.sum()
                return a.cos() + cum.sin()

        with using_dynamo(True):
            inp = ((torch.ones(6), (torch.ones(6), torch.ones(6))),)
            gm = exir.capture(Module(), inp, exir.CaptureConfig())
            self.assertTrue(torch.allclose(Module()(*inp), gm(*inp)))

    # TODO (tmanlaibaatar) remove this test
    def test_pt2_mode_with_dynamo_config(self) -> None:
        """A slice whose bound depends on the input shape yields the expected
        length (4 - 1 == 3 rows) after lowering to edge."""

        def f(x: torch.Tensor) -> torch.Tensor:
            return x[: x.shape[0] - 1]

        inp = (torch.randn(4, 5),)
        prog = exir.capture(
            f,
            inp,
            # missing dispatch key
        ).to_edge()
        # assertEqual, not assertTrue: with assertTrue the `3` was only the
        # failure message, so any non-zero length passed.
        self.assertEqual(prog(torch.randn(4, 5)).shape[0], 3)

    def test_input_container_type(self) -> None:
        """List inputs and dict outputs work with both dynamo export and exir
        capture, and the lowered program matches eager."""

        def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
            # pyre-ignore
            return {"a": x.sum() + sum(y).sum()}

        inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)])

        # pyre-fixme[23]: Unable to unpack `(...) -> Tuple[GraphModule,
        #  Set[torch._guards.Guard]]` into 2 values.
        gm, _ = torch._dynamo.export(f, *inp, aten_graph=True, tracing_mode="symbolic")
        prog = exir.capture(f, inp, config=exir.CaptureConfig()).to_edge()

        self.assertEqual(prog(*inp), f(*inp))

    def test_aot_buffer_mutation(self) -> None:
        """AOT capture of in-place buffer updates: the expected RuntimeError is
        raised when an example input requires grad; otherwise the captured
        program matches eager."""

        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer(
                    "_bin_num_examples",
                    torch.empty([42]).fill_(
                        0.0,
                    ),
                )

            def forward(self, x, y, z):
                self._bin_num_examples.index_copy_(
                    dim=0,
                    index=y,
                    source=z,
                )
                self._bin_num_examples.index_add_(
                    dim=0, index=torch.arange(4), source=x
                )
                return self._bin_num_examples - 1, x * z

        model = Module()
        example_inputs = (
            torch.randn(4, requires_grad=True),
            torch.tensor(0),
            torch.tensor(3.14),
        )

        with self.assertRaisesRegex(
            RuntimeError,
            "Found a graph input that requires gradients, and received a mutation.",
        ):
            _ = exir.capture(
                model,
                example_inputs,
                exir.CaptureConfig(
                    enable_aot=True,
                ),
            )

        # Note that model._bin_num_examples is mutated during exir.capture
        # We need to create a new_model
        new_model = Module()
        example_inputs = (
            torch.randn(4),
            torch.tensor(0),
            torch.tensor(3.14),
        )

        ep = exir.capture(
            new_model,
            example_inputs,
            exir.CaptureConfig(
                enable_aot=True,
            ),
        )

        test_inputs = (
            torch.randn(4),
            torch.tensor(0),
            torch.tensor(2.1),
        )
        graph_outputs = ep(*test_inputs)
        eager_outputs = Module()(*test_inputs)
        self.assertEqual(len(graph_outputs), 2)
        self.assertEqual(len(eager_outputs), 2)
        self.assertTrue(torch.allclose(graph_outputs[0], eager_outputs[0]))
        self.assertTrue(torch.allclose(graph_outputs[1], eager_outputs[1]))

    def test_assume_constant_by_default_prop(self) -> None:
        """With ``assume_static_by_default=True``, at least one node's ``val``
        has a concrete (integer) dimension."""

        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            if x.shape[0] > 3:
                return x.cos()
            return x.sin()

        dynamo_config = ExirDynamoConfig(assume_static_by_default=True)
        capture_config = exir.CaptureConfig(
            enable_dynamic_shape=True, _dynamo_config=dynamo_config
        )
        captured = exir.capture(
            foo, (torch.ones(6, 2), torch.ones(6, 3)), capture_config
        ).exported_program.graph_module
        found = False
        for node in captured.graph.nodes:
            # at least one input needs to have concrete dims
            if "val" in node.meta:
                fake_val = node.meta["val"]
                for dim in fake_val.shape:
                    if is_concrete_int(dim):
                        found = True

        self.assertTrue(found)

    def test_aot_config(self) -> None:
        """AOT capture lifts the buffer to a placeholder. The test checks that
        there are no ``get_attr`` nodes, that there are exactly two
        placeholders, and that the ``aten.add`` args trace back to both."""

        class FooWithBuffer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.zeros(42))

            def forward(self, x):
                return x.cos() + self.buffer.sum()

        capture_config = exir.CaptureConfig(enable_aot=True)
        captured_ep = exir.capture(FooWithBuffer(), (torch.ones(6, 2),), capture_config)
        captured_gm = captured_ep.exported_program.graph_module

        placeholder_nodes = set()
        for node in captured_gm.graph.nodes:
            self.assertFalse(node.op == "get_attr")
            if node.op == "placeholder":
                placeholder_nodes.add(node)
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                # make sure the placeholders are used
                arg_0, arg_1 = node.args
                self.assertEqual(
                    placeholder_nodes,
                    {
                        list(arg_0._input_nodes.keys())[0],
                        list(arg_1._input_nodes.keys())[0],
                    },
                )

        self.assertEqual(len(placeholder_nodes), 2)
        captured_ep.to_edge()

    def test_export_unlift(self) -> None:
        """With ``_unlift=True``, the captured program that reads a buffer
        matches eager."""

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self, x):
                return x.cos() + self.buffer.sin()

        ep = exir.capture(
            Foo(),
            (torch.ones(6, 4),),
            exir.CaptureConfig(enable_aot=True, _unlift=True),
        )

        self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4))))

    def test_export_container_unlift(self) -> None:
        """Unlifted AOT capture with a single nested-tuple input matches eager."""

        class FooContainerInputOutput(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self, x):
                return x[0][0].cos() + x[0][1].sin() + self.buffer.sin()

        inp = ((torch.ones(6, 4), torch.ones(6, 4)),)
        ep = exir.capture(
            FooContainerInputOutput(),
            (inp,),
            CaptureConfig(enable_aot=True, _unlift=True),
        )
        self.assertTrue(torch.allclose(ep(inp), FooContainerInputOutput()(inp)))

    def test_export_container_input_unlift(self) -> None:
        """Unlifted AOT capture with two tuple inputs matches eager."""

        class FooContainerInputOutputV2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self, x, y):
                return x[0].cos() + y[0].sin() + self.buffer.sin()

        inp = ((torch.ones(6, 4),), (torch.ones(6, 4),))
        ep = exir.capture(
            FooContainerInputOutputV2(),
            inp,
            CaptureConfig(enable_aot=True, _unlift=True),
        )
        self.assertTrue(torch.allclose(ep(*inp), FooContainerInputOutputV2()(*inp)))

    def test_export_cond(self) -> None:
        """Unlifted AOT capture of ``cond``, where the true branch calls a
        submodule that reads a buffer, matches eager."""

        class A(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self):
                return self.buffer.cos()

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = A()

            def forward(self, x):
                def true_fn(x):
                    return x.cos() + self.a().sum()

                def false_fn(x):
                    return x.sin()

                return cond(x.shape[0] > 4, true_fn, false_fn, [x])

        inp = torch.ones(6, 4)
        ep = exir.capture(
            Foo(),
            (inp,),
            CaptureConfig(enable_aot=True, _unlift=True),
        )
        self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4))))

    def test_export_cond_map(self) -> None:
        """Unlifted AOT capture of ``map`` over a body that calls ``cond``
        matches eager for both predicate values."""

        class A(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self):
                return self.buffer.sum()

        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = A()

            def inner(self, x, pred):
                def true_fn(x):
                    return x + x + self.a()

                def false_fn(x):
                    return x * x - self.a()

                return cond(pred, true_fn, false_fn, [x])

            def forward(self, pred, xs):
                def body(x, pred):
                    return self.inner(x, pred) + self.a()

                return map(body, xs, pred)

        inp = torch.randn(3, 2, 1)
        ep = exir.capture(
            Module(),
            (torch.tensor(True), inp),
            CaptureConfig(enable_aot=True, _unlift=True),
        )

        inp_test = torch.randn(3, 2, 1)
        self.assertTrue(
            torch.allclose(
                ep(torch.tensor(True), inp_test),
                Module()(torch.tensor(True), inp_test),
            )
        )
        self.assertTrue(
            torch.allclose(
                ep(torch.tensor(False), inp_test),
                Module()(torch.tensor(False), inp_test),
            )
        )