xref: /aosp_15_r20/external/executorch/exir/tests/test_tracer.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport copy
10*523fa7a6SAndroid Build Coastguard Workerimport unittest
11*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, List, Tuple
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir as exir
14*523fa7a6SAndroid Build Coastguard Workerimport executorch.exir.tests.models as models
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Workerimport torch
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir import CaptureConfig
19*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.dialects._ops import ops as exir_ops
20*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tests.common import register_additional_test_aten_ops
21*523fa7a6SAndroid Build Coastguard Workerfrom executorch.exir.tracer import dynamo_trace, ExirDynamoConfig, using_dynamo
22*523fa7a6SAndroid Build Coastguard Workerfrom functorch.experimental.control_flow import cond, map
23*523fa7a6SAndroid Build Coastguard Worker
24*523fa7a6SAndroid Build Coastguard Workerfrom parameterized import parameterized
25*523fa7a6SAndroid Build Coastguard Workerfrom torch._export.verifier import SpecViolationError
26*523fa7a6SAndroid Build Coastguard Workerfrom torch.fx.experimental.symbolic_shapes import is_concrete_int
27*523fa7a6SAndroid Build Coastguard Workerfrom torch.testing import FileCheck
28*523fa7a6SAndroid Build Coastguard Worker
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Workerclass TestTorchDispatchFXTracer(unittest.TestCase):
31*523fa7a6SAndroid Build Coastguard Worker    @classmethod
32*523fa7a6SAndroid Build Coastguard Worker    def setUpClass(cls) -> None:
33*523fa7a6SAndroid Build Coastguard Worker        register_additional_test_aten_ops()
34*523fa7a6SAndroid Build Coastguard Worker
35*523fa7a6SAndroid Build Coastguard Worker    def test_simple(self) -> None:
36*523fa7a6SAndroid Build Coastguard Worker        f = models.BasicSinMax()
37*523fa7a6SAndroid Build Coastguard Worker        f = (
38*523fa7a6SAndroid Build Coastguard Worker            exir.capture(f, f.get_random_inputs(), exir.CaptureConfig())
39*523fa7a6SAndroid Build Coastguard Worker            .to_edge()
40*523fa7a6SAndroid Build Coastguard Worker            .exported_program.graph_module
41*523fa7a6SAndroid Build Coastguard Worker        )
42*523fa7a6SAndroid Build Coastguard Worker
43*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").run(f.code)
44*523fa7a6SAndroid Build Coastguard Worker
45*523fa7a6SAndroid Build Coastguard Worker    def test_static_control_flow(self) -> None:
46*523fa7a6SAndroid Build Coastguard Worker        def f(pred: bool, x: torch.Tensor) -> torch.Tensor:
47*523fa7a6SAndroid Build Coastguard Worker            if pred:
48*523fa7a6SAndroid Build Coastguard Worker                return torch.sin(x).max()
49*523fa7a6SAndroid Build Coastguard Worker            else:
50*523fa7a6SAndroid Build Coastguard Worker                return torch.sin(x)
51*523fa7a6SAndroid Build Coastguard Worker
52*523fa7a6SAndroid Build Coastguard Worker        pred = True
53*523fa7a6SAndroid Build Coastguard Worker        x = torch.randn(100)
54*523fa7a6SAndroid Build Coastguard Worker        f_true = (
55*523fa7a6SAndroid Build Coastguard Worker            exir.capture(f, (pred, x), exir.CaptureConfig())
56*523fa7a6SAndroid Build Coastguard Worker            .to_edge()
57*523fa7a6SAndroid Build Coastguard Worker            .exported_program.graph_module
58*523fa7a6SAndroid Build Coastguard Worker        )
59*523fa7a6SAndroid Build Coastguard Worker
60*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").check(
61*523fa7a6SAndroid Build Coastguard Worker            "executorch_exir_dialects_edge__ops_aten_max"
62*523fa7a6SAndroid Build Coastguard Worker        ).run(f_true.code)
63*523fa7a6SAndroid Build Coastguard Worker
64*523fa7a6SAndroid Build Coastguard Worker        pred = False
65*523fa7a6SAndroid Build Coastguard Worker        f_false = (
66*523fa7a6SAndroid Build Coastguard Worker            exir.capture(f, (pred, x), exir.CaptureConfig())
67*523fa7a6SAndroid Build Coastguard Worker            .to_edge()
68*523fa7a6SAndroid Build Coastguard Worker            .exported_program.graph_module
69*523fa7a6SAndroid Build Coastguard Worker        )
70*523fa7a6SAndroid Build Coastguard Worker        FileCheck().check("executorch_exir_dialects_edge__ops_aten_sin").check_not(
71*523fa7a6SAndroid Build Coastguard Worker            "executorch_exir_dialects_edge__ops_aten_max"
72*523fa7a6SAndroid Build Coastguard Worker        ).run(f_false.code)
73*523fa7a6SAndroid Build Coastguard Worker
74*523fa7a6SAndroid Build Coastguard Worker    def test_copy(self) -> None:
75*523fa7a6SAndroid Build Coastguard Worker        f = models.BasicSinMax()
76*523fa7a6SAndroid Build Coastguard Worker        f = (
77*523fa7a6SAndroid Build Coastguard Worker            exir.capture(f, f.get_random_inputs(), exir.CaptureConfig())
78*523fa7a6SAndroid Build Coastguard Worker            .to_edge()
79*523fa7a6SAndroid Build Coastguard Worker            .exported_program.graph_module
80*523fa7a6SAndroid Build Coastguard Worker        )
81*523fa7a6SAndroid Build Coastguard Worker
82*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(isinstance(f, torch.fx.GraphModule))
83*523fa7a6SAndroid Build Coastguard Worker        g = copy.deepcopy(f)
84*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(isinstance(g, torch.fx.GraphModule))
85*523fa7a6SAndroid Build Coastguard Worker
86*523fa7a6SAndroid Build Coastguard Worker    def test_stacktrace(self) -> None:
87*523fa7a6SAndroid Build Coastguard Worker        def f(x: torch.Tensor) -> torch.Tensor:
88*523fa7a6SAndroid Build Coastguard Worker            return x + x
89*523fa7a6SAndroid Build Coastguard Worker
90*523fa7a6SAndroid Build Coastguard Worker        traced_f = (
91*523fa7a6SAndroid Build Coastguard Worker            exir.capture(f, (torch.rand(2, 2),), exir.CaptureConfig())
92*523fa7a6SAndroid Build Coastguard Worker            .to_edge()
93*523fa7a6SAndroid Build Coastguard Worker            .exported_program.graph_module
94*523fa7a6SAndroid Build Coastguard Worker        )
95*523fa7a6SAndroid Build Coastguard Worker        # Check that stacktrace is populated and retained (by checking twice)
96*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(
97*523fa7a6SAndroid Build Coastguard Worker            any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes)
98*523fa7a6SAndroid Build Coastguard Worker        )
99*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(
100*523fa7a6SAndroid Build Coastguard Worker            any(node.meta.get("stack_trace", None) for node in traced_f.graph.nodes)
101*523fa7a6SAndroid Build Coastguard Worker        )
102*523fa7a6SAndroid Build Coastguard Worker
103*523fa7a6SAndroid Build Coastguard Worker    def test_ones(self) -> None:
104*523fa7a6SAndroid Build Coastguard Worker        class M(torch.nn.Module):
105*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
106*523fa7a6SAndroid Build Coastguard Worker                y = torch.ones(x.shape[0])
107*523fa7a6SAndroid Build Coastguard Worker                return x + y
108*523fa7a6SAndroid Build Coastguard Worker
109*523fa7a6SAndroid Build Coastguard Worker        ep = torch.export.export(
110*523fa7a6SAndroid Build Coastguard Worker            M(), (torch.ones(3),), dynamic_shapes={"x": {0: torch.export.Dim("x")}}
111*523fa7a6SAndroid Build Coastguard Worker        )
112*523fa7a6SAndroid Build Coastguard Worker        exir.to_edge(ep)
113*523fa7a6SAndroid Build Coastguard Worker
114*523fa7a6SAndroid Build Coastguard Worker    def test_possible_input_mutation(self) -> None:
115*523fa7a6SAndroid Build Coastguard Worker        def f(x: torch.Tensor) -> torch.Tensor:
116*523fa7a6SAndroid Build Coastguard Worker            return torch.add(torch.ones(5), torch.ones(5), out=x)
117*523fa7a6SAndroid Build Coastguard Worker
118*523fa7a6SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
119*523fa7a6SAndroid Build Coastguard Worker            SpecViolationError,
120*523fa7a6SAndroid Build Coastguard Worker            r"operator .* is not functional",
121*523fa7a6SAndroid Build Coastguard Worker        ):
122*523fa7a6SAndroid Build Coastguard Worker            exir.capture(f, (torch.zeros(5),), exir.CaptureConfig()).to_edge()
123*523fa7a6SAndroid Build Coastguard Worker
124*523fa7a6SAndroid Build Coastguard Worker    def test_tensor_spec_for_const_tensors(self) -> None:
125*523fa7a6SAndroid Build Coastguard Worker        class Module(torch.nn.Module):
126*523fa7a6SAndroid Build Coastguard Worker            def __init__(self) -> None:
127*523fa7a6SAndroid Build Coastguard Worker                super(Module, self).__init__()
128*523fa7a6SAndroid Build Coastguard Worker                self.linear = torch.nn.Linear(2, 3)
129*523fa7a6SAndroid Build Coastguard Worker
130*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
131*523fa7a6SAndroid Build Coastguard Worker                return self.linear(x)
132*523fa7a6SAndroid Build Coastguard Worker
133*523fa7a6SAndroid Build Coastguard Worker            def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
134*523fa7a6SAndroid Build Coastguard Worker                return (torch.randn(2),)
135*523fa7a6SAndroid Build Coastguard Worker
136*523fa7a6SAndroid Build Coastguard Worker        model = Module()
137*523fa7a6SAndroid Build Coastguard Worker        graph_module = (
138*523fa7a6SAndroid Build Coastguard Worker            exir.capture(model, model.get_random_inputs(), exir.CaptureConfig())
139*523fa7a6SAndroid Build Coastguard Worker            # torch._ops.aten.t.default
140*523fa7a6SAndroid Build Coastguard Worker            .to_edge(
141*523fa7a6SAndroid Build Coastguard Worker                exir.EdgeCompileConfig(_check_ir_validity=False)
142*523fa7a6SAndroid Build Coastguard Worker            ).exported_program.graph_module
143*523fa7a6SAndroid Build Coastguard Worker        )
144*523fa7a6SAndroid Build Coastguard Worker        num_get_attr_node = 0
145*523fa7a6SAndroid Build Coastguard Worker        num_get_attr_node_with_tensorspec = 0
146*523fa7a6SAndroid Build Coastguard Worker        for nd in graph_module.graph.nodes:
147*523fa7a6SAndroid Build Coastguard Worker            if nd.op == "get_attr":
148*523fa7a6SAndroid Build Coastguard Worker                num_get_attr_node += 1
149*523fa7a6SAndroid Build Coastguard Worker                if nd.meta.get("val") is not None:
150*523fa7a6SAndroid Build Coastguard Worker                    num_get_attr_node_with_tensorspec += 1
151*523fa7a6SAndroid Build Coastguard Worker
152*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(2, num_get_attr_node)
153*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(2, num_get_attr_node_with_tensorspec)
154*523fa7a6SAndroid Build Coastguard Worker
155*523fa7a6SAndroid Build Coastguard Worker    def test_multiple_returns_spec(self) -> None:
156*523fa7a6SAndroid Build Coastguard Worker        def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
157*523fa7a6SAndroid Build Coastguard Worker            return torch.ops.aten.max.dim(x, 0, False)
158*523fa7a6SAndroid Build Coastguard Worker
159*523fa7a6SAndroid Build Coastguard Worker        cnt = 0
160*523fa7a6SAndroid Build Coastguard Worker        module = (
161*523fa7a6SAndroid Build Coastguard Worker            exir.capture(f, (torch.zeros(1, 2, 3),), exir.CaptureConfig())
162*523fa7a6SAndroid Build Coastguard Worker            .to_edge()
163*523fa7a6SAndroid Build Coastguard Worker            .exported_program.graph_module
164*523fa7a6SAndroid Build Coastguard Worker        )
165*523fa7a6SAndroid Build Coastguard Worker        for node in module.graph.nodes:
166*523fa7a6SAndroid Build Coastguard Worker            if node.target == exir_ops.edge.aten.max.dim:
167*523fa7a6SAndroid Build Coastguard Worker                cnt += 1
168*523fa7a6SAndroid Build Coastguard Worker                self.assertIsInstance(node.meta["val"], tuple)
169*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(cnt, 1)
170*523fa7a6SAndroid Build Coastguard Worker
171*523fa7a6SAndroid Build Coastguard Worker    def test_multiple_returns_pt2_mode(self) -> None:
172*523fa7a6SAndroid Build Coastguard Worker        def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
173*523fa7a6SAndroid Build Coastguard Worker            a = x * x
174*523fa7a6SAndroid Build Coastguard Worker            b = x + a
175*523fa7a6SAndroid Build Coastguard Worker            return a, b
176*523fa7a6SAndroid Build Coastguard Worker
177*523fa7a6SAndroid Build Coastguard Worker        inputs = (torch.ones(1, 2, 3),)
178*523fa7a6SAndroid Build Coastguard Worker        orig_res = f(*inputs)
179*523fa7a6SAndroid Build Coastguard Worker        module = (
180*523fa7a6SAndroid Build Coastguard Worker            exir.capture(
181*523fa7a6SAndroid Build Coastguard Worker                f,
182*523fa7a6SAndroid Build Coastguard Worker                inputs,
183*523fa7a6SAndroid Build Coastguard Worker                exir.CaptureConfig(),
184*523fa7a6SAndroid Build Coastguard Worker            )
185*523fa7a6SAndroid Build Coastguard Worker            .to_edge()
186*523fa7a6SAndroid Build Coastguard Worker            .exported_program.graph_module
187*523fa7a6SAndroid Build Coastguard Worker        )
188*523fa7a6SAndroid Build Coastguard Worker        new_res = module(*inputs)
189*523fa7a6SAndroid Build Coastguard Worker        for node in module.graph.nodes:
190*523fa7a6SAndroid Build Coastguard Worker            if node.op == "output":
191*523fa7a6SAndroid Build Coastguard Worker                self.assertIsInstance(node.meta["val"], list)
192*523fa7a6SAndroid Build Coastguard Worker                self.assertEqual(len(node.meta["val"]), 2)
193*523fa7a6SAndroid Build Coastguard Worker
194*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(orig_res[0], new_res[0]))
195*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(orig_res[1], new_res[1]))
196*523fa7a6SAndroid Build Coastguard Worker
197*523fa7a6SAndroid Build Coastguard Worker    def test_dynamo_capture_scalar_outputs(self) -> None:
198*523fa7a6SAndroid Build Coastguard Worker        def f(x: torch.Tensor) -> float:
199*523fa7a6SAndroid Build Coastguard Worker            return x.item()
200*523fa7a6SAndroid Build Coastguard Worker
201*523fa7a6SAndroid Build Coastguard Worker        gm, guards = dynamo_trace(
202*523fa7a6SAndroid Build Coastguard Worker            f,
203*523fa7a6SAndroid Build Coastguard Worker            (torch.ones(1),),
204*523fa7a6SAndroid Build Coastguard Worker            False,
205*523fa7a6SAndroid Build Coastguard Worker            "real",
206*523fa7a6SAndroid Build Coastguard Worker            ExirDynamoConfig(),
207*523fa7a6SAndroid Build Coastguard Worker        )
208*523fa7a6SAndroid Build Coastguard Worker
209*523fa7a6SAndroid Build Coastguard Worker    # pyre-ignore
210*523fa7a6SAndroid Build Coastguard Worker    @parameterized.expand([("stock_tensor",)])
211*523fa7a6SAndroid Build Coastguard Worker    def test_embedding_dynamic_shape(self, input_type: str) -> None:
212*523fa7a6SAndroid Build Coastguard Worker        class Module(torch.nn.Module):
213*523fa7a6SAndroid Build Coastguard Worker            def __init__(self) -> None:
214*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
215*523fa7a6SAndroid Build Coastguard Worker
216*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
217*523fa7a6SAndroid Build Coastguard Worker                return x + x
218*523fa7a6SAndroid Build Coastguard Worker
219*523fa7a6SAndroid Build Coastguard Worker        example_input = torch.ones(10, dtype=torch.int64)
220*523fa7a6SAndroid Build Coastguard Worker        m = Module()
221*523fa7a6SAndroid Build Coastguard Worker        gm = (
222*523fa7a6SAndroid Build Coastguard Worker            exir.capture(
223*523fa7a6SAndroid Build Coastguard Worker                m,
224*523fa7a6SAndroid Build Coastguard Worker                (example_input,),
225*523fa7a6SAndroid Build Coastguard Worker                exir.CaptureConfig(
226*523fa7a6SAndroid Build Coastguard Worker                    enable_functionalization=False,
227*523fa7a6SAndroid Build Coastguard Worker                    enable_dynamic_shape=True,
228*523fa7a6SAndroid Build Coastguard Worker                ),
229*523fa7a6SAndroid Build Coastguard Worker            )
230*523fa7a6SAndroid Build Coastguard Worker            .to_edge()
231*523fa7a6SAndroid Build Coastguard Worker            .exported_program.graph_module
232*523fa7a6SAndroid Build Coastguard Worker        )
233*523fa7a6SAndroid Build Coastguard Worker
234*523fa7a6SAndroid Build Coastguard Worker        print(gm.graph)
235*523fa7a6SAndroid Build Coastguard Worker
236*523fa7a6SAndroid Build Coastguard Worker    def test_dynamic_shape(self) -> None:
237*523fa7a6SAndroid Build Coastguard Worker        def forward(x: torch.Tensor) -> torch.Tensor:
238*523fa7a6SAndroid Build Coastguard Worker            x = x.view(x.shape[0] - 1, -1)
239*523fa7a6SAndroid Build Coastguard Worker            return torch.cat([x, x])
240*523fa7a6SAndroid Build Coastguard Worker
241*523fa7a6SAndroid Build Coastguard Worker        gm = (
242*523fa7a6SAndroid Build Coastguard Worker            exir.capture(
243*523fa7a6SAndroid Build Coastguard Worker                forward,
244*523fa7a6SAndroid Build Coastguard Worker                (torch.ones(3, 2, dtype=torch.int64),),
245*523fa7a6SAndroid Build Coastguard Worker                exir.CaptureConfig(
246*523fa7a6SAndroid Build Coastguard Worker                    enable_functionalization=False,
247*523fa7a6SAndroid Build Coastguard Worker                    enable_dynamic_shape=True,
248*523fa7a6SAndroid Build Coastguard Worker                    _dynamo_config=ExirDynamoConfig(assume_static_by_default=True),
249*523fa7a6SAndroid Build Coastguard Worker                ),
250*523fa7a6SAndroid Build Coastguard Worker                # sym_size is not reg op
251*523fa7a6SAndroid Build Coastguard Worker            )
252*523fa7a6SAndroid Build Coastguard Worker            .to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
253*523fa7a6SAndroid Build Coastguard Worker            .exported_program.graph_module
254*523fa7a6SAndroid Build Coastguard Worker        )
255*523fa7a6SAndroid Build Coastguard Worker
256*523fa7a6SAndroid Build Coastguard Worker        for node in gm.graph.nodes:
257*523fa7a6SAndroid Build Coastguard Worker            if node.op in ("placeholder", "call_function"):
258*523fa7a6SAndroid Build Coastguard Worker                self.assertIn("val", node.meta)
259*523fa7a6SAndroid Build Coastguard Worker
260*523fa7a6SAndroid Build Coastguard Worker    def test_dynamo_frontend_container_input(self) -> None:
261*523fa7a6SAndroid Build Coastguard Worker        class Module(torch.nn.Module):
262*523fa7a6SAndroid Build Coastguard Worker            def __init__(self) -> None:
263*523fa7a6SAndroid Build Coastguard Worker                super(Module, self).__init__()
264*523fa7a6SAndroid Build Coastguard Worker
265*523fa7a6SAndroid Build Coastguard Worker            def forward(
266*523fa7a6SAndroid Build Coastguard Worker                self, x: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
267*523fa7a6SAndroid Build Coastguard Worker            ) -> torch.Tensor:
268*523fa7a6SAndroid Build Coastguard Worker                a = x[0]
269*523fa7a6SAndroid Build Coastguard Worker                b = x[1]
270*523fa7a6SAndroid Build Coastguard Worker                cum = 0
271*523fa7a6SAndroid Build Coastguard Worker                for i in b:
272*523fa7a6SAndroid Build Coastguard Worker                    cum += i.sum()
273*523fa7a6SAndroid Build Coastguard Worker                return a.cos() + cum.sin()
274*523fa7a6SAndroid Build Coastguard Worker
275*523fa7a6SAndroid Build Coastguard Worker        with using_dynamo(True):
276*523fa7a6SAndroid Build Coastguard Worker            inp = ((torch.ones(6), (torch.ones(6), torch.ones(6))),)
277*523fa7a6SAndroid Build Coastguard Worker            gm = exir.capture(Module(), inp, exir.CaptureConfig())
278*523fa7a6SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(Module()(*inp), gm(*inp)))
279*523fa7a6SAndroid Build Coastguard Worker
280*523fa7a6SAndroid Build Coastguard Worker    # TODO (tmanlaibaatar) remove this test
281*523fa7a6SAndroid Build Coastguard Worker    def test_pt2_mode_with_dynamo_config(self) -> None:
282*523fa7a6SAndroid Build Coastguard Worker        def f(x: torch.Tensor) -> torch.Tensor:
283*523fa7a6SAndroid Build Coastguard Worker            return x[: x.shape[0] - 1]
284*523fa7a6SAndroid Build Coastguard Worker
285*523fa7a6SAndroid Build Coastguard Worker        inp = (torch.randn(4, 5),)
286*523fa7a6SAndroid Build Coastguard Worker        prog = exir.capture(
287*523fa7a6SAndroid Build Coastguard Worker            f,
288*523fa7a6SAndroid Build Coastguard Worker            inp,
289*523fa7a6SAndroid Build Coastguard Worker            # missing dispatch key
290*523fa7a6SAndroid Build Coastguard Worker        ).to_edge()
291*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(prog(torch.randn(4, 5)).shape[0], 3)
292*523fa7a6SAndroid Build Coastguard Worker
293*523fa7a6SAndroid Build Coastguard Worker    def test_input_container_type(self) -> None:
294*523fa7a6SAndroid Build Coastguard Worker        def f(x: torch.Tensor, y: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
295*523fa7a6SAndroid Build Coastguard Worker            # pyre-ignore
296*523fa7a6SAndroid Build Coastguard Worker            return {"a": x.sum() + sum(y).sum()}
297*523fa7a6SAndroid Build Coastguard Worker
298*523fa7a6SAndroid Build Coastguard Worker        inp = (torch.randn(6, 5), [torch.randn(6, 5), torch.randn(6, 5)])
299*523fa7a6SAndroid Build Coastguard Worker
300*523fa7a6SAndroid Build Coastguard Worker        # pyre-fixme[23]: Unable to unpack `(...) -> Tuple[GraphModule,
301*523fa7a6SAndroid Build Coastguard Worker        #  Set[torch._guards.Guard]]` into 2 values.
302*523fa7a6SAndroid Build Coastguard Worker        gm, _ = torch._dynamo.export(f, *inp, aten_graph=True, tracing_mode="symbolic")
303*523fa7a6SAndroid Build Coastguard Worker        prog = exir.capture(f, inp, config=exir.CaptureConfig()).to_edge()
304*523fa7a6SAndroid Build Coastguard Worker
305*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(prog(*inp), f(*inp))
306*523fa7a6SAndroid Build Coastguard Worker
307*523fa7a6SAndroid Build Coastguard Worker    def test_aot_buffer_mutation(self) -> None:
308*523fa7a6SAndroid Build Coastguard Worker        class Module(torch.nn.Module):
309*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
310*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
311*523fa7a6SAndroid Build Coastguard Worker                self.register_buffer(
312*523fa7a6SAndroid Build Coastguard Worker                    "_bin_num_examples",
313*523fa7a6SAndroid Build Coastguard Worker                    torch.empty([42]).fill_(
314*523fa7a6SAndroid Build Coastguard Worker                        0.0,
315*523fa7a6SAndroid Build Coastguard Worker                    ),
316*523fa7a6SAndroid Build Coastguard Worker                )
317*523fa7a6SAndroid Build Coastguard Worker
318*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x, y, z):
319*523fa7a6SAndroid Build Coastguard Worker                self._bin_num_examples.index_copy_(
320*523fa7a6SAndroid Build Coastguard Worker                    dim=0,
321*523fa7a6SAndroid Build Coastguard Worker                    index=y,
322*523fa7a6SAndroid Build Coastguard Worker                    source=z,
323*523fa7a6SAndroid Build Coastguard Worker                )
324*523fa7a6SAndroid Build Coastguard Worker                self._bin_num_examples.index_add_(
325*523fa7a6SAndroid Build Coastguard Worker                    dim=0, index=torch.arange(4), source=x
326*523fa7a6SAndroid Build Coastguard Worker                )
327*523fa7a6SAndroid Build Coastguard Worker                return self._bin_num_examples - 1, x * z
328*523fa7a6SAndroid Build Coastguard Worker
329*523fa7a6SAndroid Build Coastguard Worker        model = Module()
330*523fa7a6SAndroid Build Coastguard Worker        example_inputs = (
331*523fa7a6SAndroid Build Coastguard Worker            torch.randn(4, requires_grad=True),
332*523fa7a6SAndroid Build Coastguard Worker            torch.tensor(0),
333*523fa7a6SAndroid Build Coastguard Worker            torch.tensor(3.14),
334*523fa7a6SAndroid Build Coastguard Worker        )
335*523fa7a6SAndroid Build Coastguard Worker
336*523fa7a6SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
337*523fa7a6SAndroid Build Coastguard Worker            RuntimeError,
338*523fa7a6SAndroid Build Coastguard Worker            "Found a graph input that requires gradients, and received a mutation.",
339*523fa7a6SAndroid Build Coastguard Worker        ):
340*523fa7a6SAndroid Build Coastguard Worker            _ = exir.capture(
341*523fa7a6SAndroid Build Coastguard Worker                model,
342*523fa7a6SAndroid Build Coastguard Worker                example_inputs,
343*523fa7a6SAndroid Build Coastguard Worker                exir.CaptureConfig(
344*523fa7a6SAndroid Build Coastguard Worker                    enable_aot=True,
345*523fa7a6SAndroid Build Coastguard Worker                ),
346*523fa7a6SAndroid Build Coastguard Worker            )
347*523fa7a6SAndroid Build Coastguard Worker
348*523fa7a6SAndroid Build Coastguard Worker        # Note that model._bin_num_examples is mutated during exir.capture
349*523fa7a6SAndroid Build Coastguard Worker        # We need to create a new_model
350*523fa7a6SAndroid Build Coastguard Worker        new_model = Module()
351*523fa7a6SAndroid Build Coastguard Worker        example_inputs = (
352*523fa7a6SAndroid Build Coastguard Worker            torch.randn(4),
353*523fa7a6SAndroid Build Coastguard Worker            torch.tensor(0),
354*523fa7a6SAndroid Build Coastguard Worker            torch.tensor(3.14),
355*523fa7a6SAndroid Build Coastguard Worker        )
356*523fa7a6SAndroid Build Coastguard Worker
357*523fa7a6SAndroid Build Coastguard Worker        ep = exir.capture(
358*523fa7a6SAndroid Build Coastguard Worker            new_model,
359*523fa7a6SAndroid Build Coastguard Worker            example_inputs,
360*523fa7a6SAndroid Build Coastguard Worker            exir.CaptureConfig(
361*523fa7a6SAndroid Build Coastguard Worker                enable_aot=True,
362*523fa7a6SAndroid Build Coastguard Worker            ),
363*523fa7a6SAndroid Build Coastguard Worker        )
364*523fa7a6SAndroid Build Coastguard Worker
365*523fa7a6SAndroid Build Coastguard Worker        test_inputs = (
366*523fa7a6SAndroid Build Coastguard Worker            torch.randn(4),
367*523fa7a6SAndroid Build Coastguard Worker            torch.tensor(0),
368*523fa7a6SAndroid Build Coastguard Worker            torch.tensor(2.1),
369*523fa7a6SAndroid Build Coastguard Worker        )
370*523fa7a6SAndroid Build Coastguard Worker        graph_outputs = ep(*test_inputs)
371*523fa7a6SAndroid Build Coastguard Worker        eager_outputs = Module()(*test_inputs)
372*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(len(graph_outputs), 2)
373*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(len(eager_outputs), 2)
374*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(graph_outputs[0], eager_outputs[0]))
375*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(graph_outputs[1], eager_outputs[1]))
376*523fa7a6SAndroid Build Coastguard Worker
377*523fa7a6SAndroid Build Coastguard Worker    def test_assume_constant_by_default_prop(self) -> None:
378*523fa7a6SAndroid Build Coastguard Worker        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
379*523fa7a6SAndroid Build Coastguard Worker            if x.shape[0] > 3:
380*523fa7a6SAndroid Build Coastguard Worker                return x.cos()
381*523fa7a6SAndroid Build Coastguard Worker            return x.sin()
382*523fa7a6SAndroid Build Coastguard Worker
383*523fa7a6SAndroid Build Coastguard Worker        dynamo_config = ExirDynamoConfig(assume_static_by_default=True)
384*523fa7a6SAndroid Build Coastguard Worker        capture_config = exir.CaptureConfig(
385*523fa7a6SAndroid Build Coastguard Worker            enable_dynamic_shape=True, _dynamo_config=dynamo_config
386*523fa7a6SAndroid Build Coastguard Worker        )
387*523fa7a6SAndroid Build Coastguard Worker        captured = exir.capture(
388*523fa7a6SAndroid Build Coastguard Worker            foo, (torch.ones(6, 2), torch.ones(6, 3)), capture_config
389*523fa7a6SAndroid Build Coastguard Worker        ).exported_program.graph_module
390*523fa7a6SAndroid Build Coastguard Worker        found = False
391*523fa7a6SAndroid Build Coastguard Worker        for node in captured.graph.nodes:
392*523fa7a6SAndroid Build Coastguard Worker            # at least one input needs to have concrete dims
393*523fa7a6SAndroid Build Coastguard Worker            if "val" in node.meta:
394*523fa7a6SAndroid Build Coastguard Worker                fake_val = node.meta["val"]
395*523fa7a6SAndroid Build Coastguard Worker                for dim in fake_val.shape:
396*523fa7a6SAndroid Build Coastguard Worker                    if is_concrete_int(dim):
397*523fa7a6SAndroid Build Coastguard Worker                        found = True
398*523fa7a6SAndroid Build Coastguard Worker
399*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(found)
400*523fa7a6SAndroid Build Coastguard Worker
401*523fa7a6SAndroid Build Coastguard Worker    def test_aot_config(self) -> None:
402*523fa7a6SAndroid Build Coastguard Worker        class FooWithBuffer(torch.nn.Module):
403*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
404*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
405*523fa7a6SAndroid Build Coastguard Worker                self.register_buffer("buffer", torch.zeros(42))
406*523fa7a6SAndroid Build Coastguard Worker
407*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
408*523fa7a6SAndroid Build Coastguard Worker                return x.cos() + self.buffer.sum()
409*523fa7a6SAndroid Build Coastguard Worker
410*523fa7a6SAndroid Build Coastguard Worker        capture_config = exir.CaptureConfig(enable_aot=True)
411*523fa7a6SAndroid Build Coastguard Worker        captured_ep = exir.capture(FooWithBuffer(), (torch.ones(6, 2),), capture_config)
412*523fa7a6SAndroid Build Coastguard Worker        captured_gm = captured_ep.exported_program.graph_module
413*523fa7a6SAndroid Build Coastguard Worker
414*523fa7a6SAndroid Build Coastguard Worker        placeholder_nodes = set()
415*523fa7a6SAndroid Build Coastguard Worker        print(captured_gm.graph)
416*523fa7a6SAndroid Build Coastguard Worker        for node in captured_gm.graph.nodes:
417*523fa7a6SAndroid Build Coastguard Worker            self.assertFalse(node.op == "get_attr")
418*523fa7a6SAndroid Build Coastguard Worker            if node.op == "placeholder":
419*523fa7a6SAndroid Build Coastguard Worker                placeholder_nodes.add(node)
420*523fa7a6SAndroid Build Coastguard Worker            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
421*523fa7a6SAndroid Build Coastguard Worker                # make sure the placeholders are used
422*523fa7a6SAndroid Build Coastguard Worker                arg_0, arg_1 = node.args
423*523fa7a6SAndroid Build Coastguard Worker                self.assertEqual(
424*523fa7a6SAndroid Build Coastguard Worker                    placeholder_nodes,
425*523fa7a6SAndroid Build Coastguard Worker                    {
426*523fa7a6SAndroid Build Coastguard Worker                        list(arg_0._input_nodes.keys())[0],
427*523fa7a6SAndroid Build Coastguard Worker                        list(arg_1._input_nodes.keys())[0],
428*523fa7a6SAndroid Build Coastguard Worker                    },
429*523fa7a6SAndroid Build Coastguard Worker                )
430*523fa7a6SAndroid Build Coastguard Worker
431*523fa7a6SAndroid Build Coastguard Worker        self.assertEqual(len(placeholder_nodes), 2)
432*523fa7a6SAndroid Build Coastguard Worker        captured_ep.to_edge()
433*523fa7a6SAndroid Build Coastguard Worker
434*523fa7a6SAndroid Build Coastguard Worker    def test_export_unlift(self) -> None:
435*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
436*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
437*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
438*523fa7a6SAndroid Build Coastguard Worker                self.register_buffer("buffer", torch.ones(6, 4))
439*523fa7a6SAndroid Build Coastguard Worker
440*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
441*523fa7a6SAndroid Build Coastguard Worker                return x.cos() + self.buffer.sin()
442*523fa7a6SAndroid Build Coastguard Worker
443*523fa7a6SAndroid Build Coastguard Worker        ep = exir.capture(
444*523fa7a6SAndroid Build Coastguard Worker            Foo(),
445*523fa7a6SAndroid Build Coastguard Worker            (torch.ones(6, 4),),
446*523fa7a6SAndroid Build Coastguard Worker            exir.CaptureConfig(enable_aot=True, _unlift=True),
447*523fa7a6SAndroid Build Coastguard Worker        )
448*523fa7a6SAndroid Build Coastguard Worker
449*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4))))
450*523fa7a6SAndroid Build Coastguard Worker
451*523fa7a6SAndroid Build Coastguard Worker    def test_export_container_unlift(self) -> None:
452*523fa7a6SAndroid Build Coastguard Worker        class FooContainerInputOutput(torch.nn.Module):
453*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
454*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
455*523fa7a6SAndroid Build Coastguard Worker                self.register_buffer("buffer", torch.ones(6, 4))
456*523fa7a6SAndroid Build Coastguard Worker
457*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
458*523fa7a6SAndroid Build Coastguard Worker                return x[0][0].cos() + x[0][1].sin() + self.buffer.sin()
459*523fa7a6SAndroid Build Coastguard Worker
460*523fa7a6SAndroid Build Coastguard Worker        inp = ((torch.ones(6, 4), torch.ones(6, 4)),)
461*523fa7a6SAndroid Build Coastguard Worker        ep = exir.capture(
462*523fa7a6SAndroid Build Coastguard Worker            FooContainerInputOutput(),
463*523fa7a6SAndroid Build Coastguard Worker            (inp,),
464*523fa7a6SAndroid Build Coastguard Worker            CaptureConfig(enable_aot=True, _unlift=True),
465*523fa7a6SAndroid Build Coastguard Worker        )
466*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ep(inp), FooContainerInputOutput()(inp)))
467*523fa7a6SAndroid Build Coastguard Worker
468*523fa7a6SAndroid Build Coastguard Worker    def test_export_container_input_unlift(self) -> None:
469*523fa7a6SAndroid Build Coastguard Worker        class FooContainerInputOutputV2(torch.nn.Module):
470*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
471*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
472*523fa7a6SAndroid Build Coastguard Worker                self.register_buffer("buffer", torch.ones(6, 4))
473*523fa7a6SAndroid Build Coastguard Worker
474*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x, y):
475*523fa7a6SAndroid Build Coastguard Worker                return x[0].cos() + y[0].sin() + self.buffer.sin()
476*523fa7a6SAndroid Build Coastguard Worker
477*523fa7a6SAndroid Build Coastguard Worker        inp = ((torch.ones(6, 4),), (torch.ones(6, 4),))
478*523fa7a6SAndroid Build Coastguard Worker        ep = exir.capture(
479*523fa7a6SAndroid Build Coastguard Worker            FooContainerInputOutputV2(),
480*523fa7a6SAndroid Build Coastguard Worker            inp,
481*523fa7a6SAndroid Build Coastguard Worker            CaptureConfig(enable_aot=True, _unlift=True),
482*523fa7a6SAndroid Build Coastguard Worker        )
483*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ep(*inp), FooContainerInputOutputV2()(*inp)))
484*523fa7a6SAndroid Build Coastguard Worker
485*523fa7a6SAndroid Build Coastguard Worker    def test_export_cond(self) -> None:
486*523fa7a6SAndroid Build Coastguard Worker        class A(torch.nn.Module):
487*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
488*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
489*523fa7a6SAndroid Build Coastguard Worker                self.register_buffer("buffer", torch.ones(6, 4))
490*523fa7a6SAndroid Build Coastguard Worker
491*523fa7a6SAndroid Build Coastguard Worker            def forward(self):
492*523fa7a6SAndroid Build Coastguard Worker                return self.buffer.cos()
493*523fa7a6SAndroid Build Coastguard Worker
494*523fa7a6SAndroid Build Coastguard Worker        class Foo(torch.nn.Module):
495*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
496*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
497*523fa7a6SAndroid Build Coastguard Worker                self.a = A()
498*523fa7a6SAndroid Build Coastguard Worker
499*523fa7a6SAndroid Build Coastguard Worker            def forward(self, x):
500*523fa7a6SAndroid Build Coastguard Worker                def true_fn(x):
501*523fa7a6SAndroid Build Coastguard Worker                    return x.cos() + self.a().sum()
502*523fa7a6SAndroid Build Coastguard Worker
503*523fa7a6SAndroid Build Coastguard Worker                def false_fn(x):
504*523fa7a6SAndroid Build Coastguard Worker                    return x.sin()
505*523fa7a6SAndroid Build Coastguard Worker
506*523fa7a6SAndroid Build Coastguard Worker                return cond(x.shape[0] > 4, true_fn, false_fn, [x])
507*523fa7a6SAndroid Build Coastguard Worker
508*523fa7a6SAndroid Build Coastguard Worker        inp = torch.ones(6, 4)
509*523fa7a6SAndroid Build Coastguard Worker        ep = exir.capture(
510*523fa7a6SAndroid Build Coastguard Worker            Foo(),
511*523fa7a6SAndroid Build Coastguard Worker            (inp,),
512*523fa7a6SAndroid Build Coastguard Worker            CaptureConfig(enable_aot=True, _unlift=True),
513*523fa7a6SAndroid Build Coastguard Worker        )
514*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(ep(torch.ones(6, 4)), Foo()(torch.ones(6, 4))))
515*523fa7a6SAndroid Build Coastguard Worker
516*523fa7a6SAndroid Build Coastguard Worker    def test_export_cond_map(self) -> None:
517*523fa7a6SAndroid Build Coastguard Worker        class A(torch.nn.Module):
518*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
519*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
520*523fa7a6SAndroid Build Coastguard Worker                self.register_buffer("buffer", torch.ones(6, 4))
521*523fa7a6SAndroid Build Coastguard Worker
522*523fa7a6SAndroid Build Coastguard Worker            def forward(self):
523*523fa7a6SAndroid Build Coastguard Worker                return self.buffer.sum()
524*523fa7a6SAndroid Build Coastguard Worker
525*523fa7a6SAndroid Build Coastguard Worker        class Module(torch.nn.Module):
526*523fa7a6SAndroid Build Coastguard Worker            def __init__(self):
527*523fa7a6SAndroid Build Coastguard Worker                super().__init__()
528*523fa7a6SAndroid Build Coastguard Worker                self.a = A()
529*523fa7a6SAndroid Build Coastguard Worker
530*523fa7a6SAndroid Build Coastguard Worker            def inner(self, x, pred):
531*523fa7a6SAndroid Build Coastguard Worker                def true_fn(x):
532*523fa7a6SAndroid Build Coastguard Worker                    return x + x + self.a()
533*523fa7a6SAndroid Build Coastguard Worker
534*523fa7a6SAndroid Build Coastguard Worker                def false_fn(x):
535*523fa7a6SAndroid Build Coastguard Worker                    return x * x - self.a()
536*523fa7a6SAndroid Build Coastguard Worker
537*523fa7a6SAndroid Build Coastguard Worker                return cond(pred, true_fn, false_fn, [x])
538*523fa7a6SAndroid Build Coastguard Worker
539*523fa7a6SAndroid Build Coastguard Worker            def forward(self, pred, xs):
540*523fa7a6SAndroid Build Coastguard Worker                def body(x, pred):
541*523fa7a6SAndroid Build Coastguard Worker                    return self.inner(x, pred) + self.a()
542*523fa7a6SAndroid Build Coastguard Worker
543*523fa7a6SAndroid Build Coastguard Worker                return map(body, xs, pred)
544*523fa7a6SAndroid Build Coastguard Worker
545*523fa7a6SAndroid Build Coastguard Worker        inp = torch.randn(3, 2, 1)
546*523fa7a6SAndroid Build Coastguard Worker        ep = exir.capture(
547*523fa7a6SAndroid Build Coastguard Worker            Module(),
548*523fa7a6SAndroid Build Coastguard Worker            (torch.tensor(True), inp),
549*523fa7a6SAndroid Build Coastguard Worker            CaptureConfig(enable_aot=True, _unlift=True),
550*523fa7a6SAndroid Build Coastguard Worker        )
551*523fa7a6SAndroid Build Coastguard Worker
552*523fa7a6SAndroid Build Coastguard Worker        inp_test = torch.randn(3, 2, 1)
553*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(
554*523fa7a6SAndroid Build Coastguard Worker            torch.allclose(
555*523fa7a6SAndroid Build Coastguard Worker                ep(torch.tensor(True), inp_test),
556*523fa7a6SAndroid Build Coastguard Worker                Module()(torch.tensor(True), inp_test),
557*523fa7a6SAndroid Build Coastguard Worker            )
558*523fa7a6SAndroid Build Coastguard Worker        )
559*523fa7a6SAndroid Build Coastguard Worker        self.assertTrue(
560*523fa7a6SAndroid Build Coastguard Worker            torch.allclose(
561*523fa7a6SAndroid Build Coastguard Worker                ep(torch.tensor(False), inp_test),
562*523fa7a6SAndroid Build Coastguard Worker                Module()(torch.tensor(False), inp_test),
563*523fa7a6SAndroid Build Coastguard Worker            )
564*523fa7a6SAndroid Build Coastguard Worker        )
565