# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import copy
import unittest
from typing import Dict, List, Tuple

import executorch.exir as exir
import executorch.exir.tests.models as models

import torch

from executorch.exir import CaptureConfig
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.tests.common import register_additional_test_aten_ops
from executorch.exir.tracer import dynamo_trace, ExirDynamoConfig, using_dynamo
from functorch.experimental.control_flow import cond, map

from parameterized import parameterized
from torch._export.verifier import SpecViolationError
from torch.fx.experimental.symbolic_shapes import is_concrete_int
from torch.testing import FileCheck


class TestTorchDispatchFXTracer(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        register_additional_test_aten_ops()

    def test_simple(self) -> None:
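        # The model's aten.sin should be lowered to its edge-dialect
        # counterpart in the generated graph code.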
        f = models.BasicSinMax()
        f = (
            exir.capture(f, f.get_random_inputs(), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )

        FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").run(f.code)

    def test_static_control_flow(self) -> None:
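        # Python-level branching is specialized at capture time: tracing with
        # pred=True bakes in the max() branch, while pred=False leaves it out.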
        def f(pred: bool, x: torch.Tensor) -> torch.Tensor:
            if pred:
                return torch.sin(x).max()
            else:
                return torch.sin(x)

        pred = True
        x = torch.randn(100)
        f_true = (
            exir.capture(f, (pred, x), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )

        FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").check(
            "executorch_exir_dialects_edge__ops_aten_max"
        ).run(f_true.code)

        pred = False
        f_false = (
            exir.capture(f, (pred, x), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )
        FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").check_not(
            "executorch_exir_dialects_edge__ops_aten_max"
        ).run(f_false.code)

    def test_copy(self) -> None:
        f = models.BasicSinMax()
        f = (
            exir.capture(f, f.get_random_inputs(), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )

        self.assertTrue(isinstance(f, torch.fx.GraphModule))
        g = copy.deepcopy(f)
        self.assertTrue(isinstance(g, torch.fx.GraphModule))

    def test_stacktrace(self) -> None:
        def f(x: torch.Tensor) -> torch.Tensor:
            return x + x

        traced_f = (
            exir.capture(f, (torch.rand(2, 2),), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )
        # Check that stacktrace is populated and retained (by checking twice)
        self.assertTrue(
            any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes)
        )
        self.assertTrue(
            any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes)
        )

    def test_ones(self) -> None:
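        # A tensor whose size comes from a dynamic dim should still export and
        # convert to edge without raising.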
        class M(torch.nn.Module):
            def forward(self, x):
                y = torch.ones(x.shape[0])
                return x + y

        ep = torch.export.export(
            M(), (torch.ones(3),), dynamic_shapes={"x": {0: torch.export.Dim("x")}}
        )
        exir.to_edge(ep)

    def test_possible_input_mutation(self) -> None:
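        # The out= overload writes into its argument, so the edge verifier is
        # expected to reject the captured graph as non-functional.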
        def f(x: torch.Tensor) -> torch.Tensor:
            return torch.add(torch.ones(5), torch.ones(5), out=x)

        with self.assertRaisesRegex(
            SpecViolationError,
            r"operator .* is not functional",
        ):
            exir.capture(f, (torch.zeros(5),), exir.CaptureConfig()).to_edge()

    def test_tensor_spec_for_const_tensors(self) -> None:
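        # The linear layer contributes two constant tensors (weight and bias);
        # both should show up as get_attr nodes carrying a "val" meta entry.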
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super(Module, self).__init__()
                self.linear = torch.nn.Linear(2, 3)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.linear(x)

            def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
                return (torch.randn(2),)

        model = Module()
        graph_module = (
            exir.capture(model, model.get_random_inputs(), exir.CaptureConfig())
            # torch._ops.aten.t.default
            .to_edge(
                exir.EdgeCompileConfig(_check_ir_validity=False)
            ).exported_program.graph_module
        )
        num_get_attr_node = 0
        num_get_attr_node_with_tensorspec = 0
        for nd in graph_module.graph.nodes:
            if nd.op == "get_attr":
                num_get_attr_node += 1
                if nd.meta.get("val") is not None:
                    num_get_attr_node_with_tensorspec += 1

        self.assertEqual(2, num_get_attr_node)
        self.assertEqual(2, num_get_attr_node_with_tensorspec)

    def test_multiple_returns_spec(self) -> None:
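        # aten.max.dim returns (values, indices), so the node's "val" meta is
        # expected to be a tuple rather than a single fake tensor.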
        def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            return torch.ops.aten.max.dim(x, 0, False)

        cnt = 0
        module = (
            exir.capture(f, (torch.zeros(1, 2, 3),), exir.CaptureConfig())
            .to_edge()
            .exported_program.graph_module
        )
        for node in module.graph.nodes:
            if node.target == exir_ops.edge.aten.max.dim:
                cnt += 1
                self.assertIsInstance(node.meta["val"], tuple)
        self.assertEqual(cnt, 1)

    def test_multiple_returns_pt2_mode(self) -> None:
        def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
            a = x * x
            b = x + a
            return a, b

        inputs = (torch.ones(1, 2, 3),)
        orig_res = f(*inputs)
        module = (
            exir.capture(
                f,
                inputs,
                exir.CaptureConfig(),
            )
            .to_edge()
            .exported_program.graph_module
        )
        new_res = module(*inputs)
        for node in module.graph.nodes:
            if node.op == "output":
                self.assertIsInstance(node.meta["val"], list)
                self.assertEqual(len(node.meta["val"]), 2)

        self.assertTrue(torch.allclose(orig_res[0], new_res[0]))
        self.assertTrue(torch.allclose(orig_res[1], new_res[1]))

    def test_dynamo_capture_scalar_outputs(self) -> None:
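        # x.item() returns a Python scalar; this test only checks that
        # dynamo_trace can capture such a function without raising.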
        def f(x: torch.Tensor) -> float:
            return x.item()

        gm, guards = dynamo_trace(
            f,
            (torch.ones(1),),
            False,
            "real",
            ExirDynamoConfig(),
        )

    # pyre-ignore
    @parameterized.expand([("stock_tensor",)])
    def test_embedding_dynamic_shape(self, input_type: str) -> None:
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x + x

        example_input = torch.ones(10, dtype=torch.int64)
        m = Module()
        gm = (
            exir.capture(
                m,
                (example_input,),
                exir.CaptureConfig(
                    enable_functionalization=False,
                    enable_dynamic_shape=True,
                ),
            )
            .to_edge()
            .exported_program.graph_module
        )

        print(gm.graph)

    def test_dynamic_shape(self) -> None:
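        # With enable_dynamic_shape, every placeholder and call_function node
        # in the captured graph should carry a "val" entry in its meta.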
        def forward(x: torch.Tensor) -> torch.Tensor:
            x = x.view(x.shape[0] - 1, -1)
            return torch.cat([x, x])

        gm = (
            exir.capture(
                forward,
                (torch.ones(3, 2, dtype=torch.int64),),
                exir.CaptureConfig(
                    enable_functionalization=False,
                    enable_dynamic_shape=True,
                    _dynamo_config=ExirDynamoConfig(assume_static_by_default=True),
                ),
                # sym_size is not reg op
            )
            .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
            .exported_program.graph_module
        )

        for node in gm.graph.nodes:
            if node.op in ("placeholder", "call_function"):
                self.assertIn("val", node.meta)

    def test_dynamo_frontend_container_input(self) -> None:
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super(Module, self).__init__()

            def forward(
                self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
            ) -> torch.Tensor:
                a = x[0]
                b = x[1]
                cum = 0
                for i in b:
                    cum += i.sum()
                return a.cos() + cum.sin()

        with using_dynamo(True):
            inp = ((torch.ones(6), (torch.ones(6), torch.ones(6))),)
            gm = exir.capture(Module(), inp, exir.CaptureConfig())
            self.assertTrue(torch.allclose(Module()(*inp), gm(*inp)))

    # TODO (tmanlaibaatar) remove this test
    def test_pt2_mode_with_dynamo_config(self) -> None:
        def f(x: torch.Tensor) -> torch.Tensor:
            return x[: x.shape[0] - 1]

        inp = (torch.randn(4, 5),)
        prog = exir.capture(
            f,
            inp,
            # missing dispatch key
        ).to_edge()
        self.assertEqual(prog(torch.randn(4, 5)).shape[0], 3)

    def test_input_container_type(self) -> None:
        def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
            # pyre-ignore
            return {"a": x.sum() + sum(y).sum()}

        inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)])

        # pyre-fixme[23]: Unable to unpack `(...) -> Tuple[GraphModule,
        #  Set[torch._guards.Guard]]` into 2 values.
        gm, _ = torch._dynamo.export(f, *inp, aten_graph=True, tracing_mode="symbolic")
        prog = exir.capture(f, inp, config=exir.CaptureConfig()).to_edge()

        self.assertEqual(prog(*inp), f(*inp))

    def test_aot_buffer_mutation(self) -> None:
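        # With enable_aot, mutating a buffer while a graph input requires grad
        # is expected to fail during capture; with plain inputs the buffer
        # mutation is captured and the results match eager execution.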
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer(
                    "_bin_num_examples",
                    torch.empty([42]).fill_(
                        0.0,
                    ),
                )

            def forward(self, x, y, z):
                self._bin_num_examples.index_copy_(
                    dim=0,
                    index=y,
                    source=z,
                )
                self._bin_num_examples.index_add_(
                    dim=0, index=torch.arange(4), source=x
                )
                return self._bin_num_examples - 1, x * z

        model = Module()
        example_inputs = (
            torch.randn(4, requires_grad=True),
            torch.tensor(0),
            torch.tensor(3.14),
        )

        with self.assertRaisesRegex(
            RuntimeError,
            "Found a graph input that requires gradients, and received a mutation.",
        ):
            _ = exir.capture(
                model,
                example_inputs,
                exir.CaptureConfig(
                    enable_aot=True,
                ),
            )

        # Note that model._bin_num_examples is mutated during exir.capture
        # We need to create a new_model
        new_model = Module()
        example_inputs = (
            torch.randn(4),
            torch.tensor(0),
            torch.tensor(3.14),
        )

        ep = exir.capture(
            new_model,
            example_inputs,
            exir.CaptureConfig(
                enable_aot=True,
            ),
        )

        test_inputs = (
            torch.randn(4),
            torch.tensor(0),
            torch.tensor(2.1),
        )
        graph_outputs = ep(*test_inputs)
        eager_outputs = Module()(*test_inputs)
        self.assertEqual(len(graph_outputs), 2)
        self.assertEqual(len(eager_outputs), 2)
        self.assertTrue(torch.allclose(graph_outputs[0], eager_outputs[0]))
        self.assertTrue(torch.allclose(graph_outputs[1], eager_outputs[1]))

    def test_assume_constant_by_default_prop(self) -> None:
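        # With assume_static_by_default, dims that are not marked dynamic are
        # specialized, so the captured graph should carry concrete dim sizes.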
        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            if x.shape[0] > 3:
                return x.cos()
            return x.sin()

        dynamo_config = ExirDynamoConfig(assume_static_by_default=True)
        capture_config = exir.CaptureConfig(
            enable_dynamic_shape=True, _dynamo_config=dynamo_config
        )
        captured = exir.capture(
            foo, (torch.ones(6, 2), torch.ones(6, 3)), capture_config
        ).exported_program.graph_module
        found = False
        for node in captured.graph.nodes:
            # at least one input needs to have concrete dims
            if "val" in node.meta:
                fake_val = node.meta["val"]
                for dim in fake_val.shape:
                    if is_concrete_int(dim):
                        found = True

        self.assertTrue(found)

    def test_aot_config(self) -> None:
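        # enable_aot lifts buffers into graph inputs: the captured graph should
        # contain no get_attr nodes and exactly two placeholders (the input
        # tensor and the lifted buffer), both consumed by the add.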
        class FooWithBuffer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.zeros(42))

            def forward(self, x):
                return x.cos() + self.buffer.sum()

        capture_config = exir.CaptureConfig(enable_aot=True)
        captured_ep = exir.capture(FooWithBuffer(), (torch.ones(6, 2),), capture_config)
        captured_gm = captured_ep.exported_program.graph_module

        placeholder_nodes = set()
        print(captured_gm.graph)
        for node in captured_gm.graph.nodes:
            self.assertFalse(node.op == "get_attr")
            if node.op == "placeholder":
                placeholder_nodes.add(node)
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                # make sure the placeholders are used
                arg_0, arg_1 = node.args
                self.assertEqual(
                    placeholder_nodes,
                    {
                        list(arg_0._input_nodes.keys())[0],
                        list(arg_1._input_nodes.keys())[0],
                    },
                )

        self.assertEqual(len(placeholder_nodes), 2)
        captured_ep.to_edge()

    def test_export_unlift(self) -> None:
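        # _unlift puts lifted buffers back onto the exported artifact so it can
        # be called like the eager module and should produce matching outputs.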
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self, x):
                return x.cos() + self.buffer.sin()

        ep = exir.capture(
            Foo(),
            (torch.ones(6, 4),),
            exir.CaptureConfig(enable_aot=True, _unlift=True),
        )

        self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4))))

    def test_export_container_unlift(self) -> None:
        class FooContainerInputOutput(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self, x):
                return x[0][0].cos() + x[0][1].sin() + self.buffer.sin()

        inp = ((torch.ones(6, 4), torch.ones(6, 4)),)
        ep = exir.capture(
            FooContainerInputOutput(),
            (inp,),
            CaptureConfig(enable_aot=True, _unlift=True),
        )
        self.assertTrue(torch.allclose(ep(inp), FooContainerInputOutput()(inp)))

    def test_export_container_input_unlift(self) -> None:
        class FooContainerInputOutputV2(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self, x, y):
                return x[0].cos() + y[0].sin() + self.buffer.sin()

        inp = ((torch.ones(6, 4),), (torch.ones(6, 4),))
        ep = exir.capture(
            FooContainerInputOutputV2(),
            inp,
            CaptureConfig(enable_aot=True, _unlift=True),
        )
        self.assertTrue(torch.allclose(ep(*inp), FooContainerInputOutputV2()(*inp)))

    def test_export_cond(self) -> None:
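        # control_flow.cond with branch functions that close over a submodule
        # should export with _unlift and match the eager result.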
        class A(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self):
                return self.buffer.cos()

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = A()

            def forward(self, x):
                def true_fn(x):
                    return x.cos() + self.a().sum()

                def false_fn(x):
                    return x.sin()

                return cond(x.shape[0] > 4, true_fn, false_fn, [x])

        inp = torch.ones(6, 4)
        ep = exir.capture(
            Foo(),
            (inp,),
            CaptureConfig(enable_aot=True, _unlift=True),
        )
        self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4))))

    def test_export_cond_map(self) -> None:
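        # map applies `body` to each slice of xs along dim 0, with a nested
        # cond inside; the unlifted export should match eager for both
        # predicate values.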
        class A(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buffer", torch.ones(6, 4))

            def forward(self):
                return self.buffer.sum()

        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = A()

            def inner(self, x, pred):
                def true_fn(x):
                    return x + x + self.a()

                def false_fn(x):
                    return x * x - self.a()

                return cond(pred, true_fn, false_fn, [x])

            def forward(self, pred, xs):
                def body(x, pred):
                    return self.inner(x, pred) + self.a()

                return map(body, xs, pred)

        inp = torch.randn(3, 2, 1)
        ep = exir.capture(
            Module(),
            (torch.tensor(True), inp),
            CaptureConfig(enable_aot=True, _unlift=True),
        )

        inp_test = torch.randn(3, 2, 1)
        self.assertTrue(
            torch.allclose(
                ep(torch.tensor(True), inp_test),
                Module()(torch.tensor(True), inp_test),
            )
        )
        self.assertTrue(
            torch.allclose(
                ep(torch.tensor(False), inp_test),
                Module()(torch.tensor(False), inp_test),
            )
        )