xref: /aosp_15_r20/external/pytorch/test/jit/test_async.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Tuple
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable
12*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
13*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir)
14*da0073e9SAndroid Build Coastguard Workerfrom typing import List
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
17*da0073e9SAndroid Build Coastguard Workerfrom torch.jit import Future
18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import _inline_everything, JitTestCase
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Workerclass TestAsync(JitTestCase):
22*da0073e9SAndroid Build Coastguard Worker    def test_async_python(self):
23*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
24*da0073e9SAndroid Build Coastguard Worker        def foo(x):
25*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x)
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
28*da0073e9SAndroid Build Coastguard Worker        fut = torch.jit.fork(foo, x)
29*da0073e9SAndroid Build Coastguard Worker        y_hat = foo(x)
30*da0073e9SAndroid Build Coastguard Worker        y = torch.jit.wait(fut)
31*da0073e9SAndroid Build Coastguard Worker        # assert nothing; only to make sure the fake python path works
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker    def test_async_future_type_python(self):
34*da0073e9SAndroid Build Coastguard Worker        def foo(inp):
35*da0073e9SAndroid Build Coastguard Worker            futures = torch.jit.annotate(List[torch.jit.Future[torch.Tensor]], [])
36*da0073e9SAndroid Build Coastguard Worker            for i in range(5):
37*da0073e9SAndroid Build Coastguard Worker                futures.append(torch.jit.fork(lambda x: x, inp))
38*da0073e9SAndroid Build Coastguard Worker            all_outputs = []
39*da0073e9SAndroid Build Coastguard Worker            for future in futures:
40*da0073e9SAndroid Build Coastguard Worker                all_outputs.append(torch.jit.wait(future))
41*da0073e9SAndroid Build Coastguard Worker            return all_outputs
42*da0073e9SAndroid Build Coastguard Worker
43*da0073e9SAndroid Build Coastguard Worker        # assert nothing, just to make sure python type parsing works
44*da0073e9SAndroid Build Coastguard Worker        foo(torch.randn(3, 4))
45*da0073e9SAndroid Build Coastguard Worker
46*da0073e9SAndroid Build Coastguard Worker    def test_async_parsing(self):
47*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
48*da0073e9SAndroid Build Coastguard Worker        def foo(x: Tensor) -> List[Tensor]:
49*da0073e9SAndroid Build Coastguard Worker            return [torch.neg(x), x.t()]
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
52*da0073e9SAndroid Build Coastguard Worker        def bar(x):
53*da0073e9SAndroid Build Coastguard Worker            futures = torch.jit.annotate(List[Future[List[Tensor]]], [])
54*da0073e9SAndroid Build Coastguard Worker            for _ in range(3):
55*da0073e9SAndroid Build Coastguard Worker                future = torch.jit.annotate(
56*da0073e9SAndroid Build Coastguard Worker                    Future[List[Tensor]], torch.jit.fork(foo, x)
57*da0073e9SAndroid Build Coastguard Worker                )
58*da0073e9SAndroid Build Coastguard Worker                futures.append(future)
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker            output = torch.jit.annotate(List[List[Tensor]], [])
61*da0073e9SAndroid Build Coastguard Worker            for i in range(3):
62*da0073e9SAndroid Build Coastguard Worker                output.append(torch.jit.wait(futures[i]))
63*da0073e9SAndroid Build Coastguard Worker            return output
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 3)
66*da0073e9SAndroid Build Coastguard Worker        result = bar(x)
67*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(result), 3)
68*da0073e9SAndroid Build Coastguard Worker
69*da0073e9SAndroid Build Coastguard Worker    def test_async_script(self):
70*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
71*da0073e9SAndroid Build Coastguard Worker        def foo(x):
72*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x), x
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
77*da0073e9SAndroid Build Coastguard Worker        def wait_script(x):
78*da0073e9SAndroid Build Coastguard Worker            fut = torch.jit.fork(foo, x)
79*da0073e9SAndroid Build Coastguard Worker            y_hat = foo(x)
80*da0073e9SAndroid Build Coastguard Worker            y = torch.jit.wait(fut)
81*da0073e9SAndroid Build Coastguard Worker            return y, y_hat
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker        y, y_hat = wait_script(x)
84*da0073e9SAndroid Build Coastguard Worker
85*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y_hat)
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker    def test_async_script_capture(self):
88*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.jit.ScriptModule):
89*da0073e9SAndroid Build Coastguard Worker            __constants__ = ["const"]
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
92*da0073e9SAndroid Build Coastguard Worker                super().__init__()
93*da0073e9SAndroid Build Coastguard Worker                self.const = 42
94*da0073e9SAndroid Build Coastguard Worker                self.param = nn.Parameter(torch.randn(2, 2))
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
97*da0073e9SAndroid Build Coastguard Worker            def foo(self, x1, x2):
98*da0073e9SAndroid Build Coastguard Worker                return torch.neg(x1), self.param, self.const, torch.neg(x2), self.param
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
101*da0073e9SAndroid Build Coastguard Worker            def forward(self, x1, x2):
102*da0073e9SAndroid Build Coastguard Worker                fut = torch.jit.fork(self.foo, x1, x2)
103*da0073e9SAndroid Build Coastguard Worker                y_hat = self.foo(x1, x2)
104*da0073e9SAndroid Build Coastguard Worker                y = torch.jit.wait(fut)
105*da0073e9SAndroid Build Coastguard Worker                return y, y_hat
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker        x1 = torch.rand(3, 4)
108*da0073e9SAndroid Build Coastguard Worker        x2 = torch.rand(5, 6)
109*da0073e9SAndroid Build Coastguard Worker
110*da0073e9SAndroid Build Coastguard Worker        m = Mod()
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker        with torch.jit.optimized_execution(False):
113*da0073e9SAndroid Build Coastguard Worker            y, y_hat = m.forward(x1, x2)
114*da0073e9SAndroid Build Coastguard Worker
115*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y_hat)
116*da0073e9SAndroid Build Coastguard Worker
117*da0073e9SAndroid Build Coastguard Worker    def test_async_script_nested(self):
118*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
119*da0073e9SAndroid Build Coastguard Worker        def foo(x):
120*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x), x
121*da0073e9SAndroid Build Coastguard Worker
122*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
123*da0073e9SAndroid Build Coastguard Worker
124*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
125*da0073e9SAndroid Build Coastguard Worker        def wait_script(x):
126*da0073e9SAndroid Build Coastguard Worker            fut = torch.jit._fork(foo, x)
127*da0073e9SAndroid Build Coastguard Worker            y_hat = foo(x)
128*da0073e9SAndroid Build Coastguard Worker            y = torch.jit._wait(fut)
129*da0073e9SAndroid Build Coastguard Worker            return y, y_hat
130*da0073e9SAndroid Build Coastguard Worker
131*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
132*da0073e9SAndroid Build Coastguard Worker        def wait_script_nest(x):
133*da0073e9SAndroid Build Coastguard Worker            fut = torch.jit._fork(wait_script, x)
134*da0073e9SAndroid Build Coastguard Worker            return torch.jit._wait(fut)
135*da0073e9SAndroid Build Coastguard Worker
136*da0073e9SAndroid Build Coastguard Worker        y, y_hat = wait_script_nest(x)
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y_hat)
139*da0073e9SAndroid Build Coastguard Worker
140*da0073e9SAndroid Build Coastguard Worker    def test_async_script_no_script_mod(self):
141*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
142*da0073e9SAndroid Build Coastguard Worker
143*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
144*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "cannot call a value", "torch.jit._fork(x"
145*da0073e9SAndroid Build Coastguard Worker        ):
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
148*da0073e9SAndroid Build Coastguard Worker            def wait_script(x):
149*da0073e9SAndroid Build Coastguard Worker                fut = torch.jit._fork(x)
150*da0073e9SAndroid Build Coastguard Worker                return fut
151*da0073e9SAndroid Build Coastguard Worker
152*da0073e9SAndroid Build Coastguard Worker    def test_async_script_multi_waits(self):
153*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
154*da0073e9SAndroid Build Coastguard Worker        def foo(x):
155*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x).t() + x
156*da0073e9SAndroid Build Coastguard Worker
157*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
158*da0073e9SAndroid Build Coastguard Worker        def wait_script(x):
159*da0073e9SAndroid Build Coastguard Worker            fut = torch.jit._fork(foo, x)
160*da0073e9SAndroid Build Coastguard Worker
161*da0073e9SAndroid Build Coastguard Worker            # wait twice on the same future
162*da0073e9SAndroid Build Coastguard Worker            y1 = torch.jit._wait(fut)
163*da0073e9SAndroid Build Coastguard Worker            y2 = torch.jit._wait(fut)
164*da0073e9SAndroid Build Coastguard Worker            return y1, y2
165*da0073e9SAndroid Build Coastguard Worker
166*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, 2)
167*da0073e9SAndroid Build Coastguard Worker        y1, y2 = wait_script(x)
168*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y1, y2)
169*da0073e9SAndroid Build Coastguard Worker
170*da0073e9SAndroid Build Coastguard Worker    def test_async_script_multi_forks(self):
171*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
172*da0073e9SAndroid Build Coastguard Worker        def foo1(x):
173*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x).t() + x
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
176*da0073e9SAndroid Build Coastguard Worker        def foo2(x, y):
177*da0073e9SAndroid Build Coastguard Worker            return torch.neg(x).t() + x + torch.neg(y).t()
178*da0073e9SAndroid Build Coastguard Worker
179*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
180*da0073e9SAndroid Build Coastguard Worker        def foo3(x, y, z):
181*da0073e9SAndroid Build Coastguard Worker            return torch.neg(z).t() + y.t() + x
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Worker        x1 = torch.rand(10, 10)
184*da0073e9SAndroid Build Coastguard Worker        x2 = torch.rand(10, 10)
185*da0073e9SAndroid Build Coastguard Worker        x3 = torch.rand(10, 10)
186*da0073e9SAndroid Build Coastguard Worker
187*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
188*da0073e9SAndroid Build Coastguard Worker        def wait_script(x1, x2, x3):
189*da0073e9SAndroid Build Coastguard Worker            f1 = torch.jit._fork(foo1, x1)
190*da0073e9SAndroid Build Coastguard Worker            f2 = torch.jit._fork(foo2, x1, x2)
191*da0073e9SAndroid Build Coastguard Worker            f3 = torch.jit._fork(foo3, x1, x2, x3)
192*da0073e9SAndroid Build Coastguard Worker            f4 = torch.jit._fork(foo1, x2)
193*da0073e9SAndroid Build Coastguard Worker            f5 = torch.jit._fork(foo2, x2, x3)
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker            # ignore some forks
196*da0073e9SAndroid Build Coastguard Worker            y1 = torch.jit._wait(f1)
197*da0073e9SAndroid Build Coastguard Worker            y2 = torch.jit._wait(f2)
198*da0073e9SAndroid Build Coastguard Worker            y3 = torch.jit._wait(f3)
199*da0073e9SAndroid Build Coastguard Worker
200*da0073e9SAndroid Build Coastguard Worker            return y1, y2, y3
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker        y1, y2, y3 = wait_script(x1, x2, x3)
203*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y1, foo1(x1))
204*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y2, foo2(x1, x2))
205*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y3, foo3(x1, x2, x3))
206*da0073e9SAndroid Build Coastguard Worker
207*da0073e9SAndroid Build Coastguard Worker    def test_async_kwargs(self):
208*da0073e9SAndroid Build Coastguard Worker        def foo(x1, x2):
209*da0073e9SAndroid Build Coastguard Worker            return 2 * x1 + x2
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker        x1 = torch.rand(3, 4)
212*da0073e9SAndroid Build Coastguard Worker        x2 = torch.rand(3, 4)
213*da0073e9SAndroid Build Coastguard Worker        y_hat = foo(x1, x2)
214*da0073e9SAndroid Build Coastguard Worker
215*da0073e9SAndroid Build Coastguard Worker        # Cover tracing and bare functions with permutations of args, kwargs
216*da0073e9SAndroid Build Coastguard Worker        for func in [
217*da0073e9SAndroid Build Coastguard Worker            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2)),
218*da0073e9SAndroid Build Coastguard Worker            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2=x2)),
219*da0073e9SAndroid Build Coastguard Worker            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)),
220*da0073e9SAndroid Build Coastguard Worker            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1)),
221*da0073e9SAndroid Build Coastguard Worker        ]:
222*da0073e9SAndroid Build Coastguard Worker            for wrapper in [
223*da0073e9SAndroid Build Coastguard Worker                func,
224*da0073e9SAndroid Build Coastguard Worker                torch.jit.trace(func, (x1, x2)),
225*da0073e9SAndroid Build Coastguard Worker            ]:
226*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(wrapper(x1, x2), y_hat)
227*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(wrapper(x1, x2=x2), y_hat)
228*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(wrapper(x1=x1, x2=x2), y_hat)
229*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(wrapper(x2=x2, x1=x1), y_hat)
230*da0073e9SAndroid Build Coastguard Worker
231*da0073e9SAndroid Build Coastguard Worker        # Cover scripting
232*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
233*da0073e9SAndroid Build Coastguard Worker        def foo_script_args(x1, x2):
234*da0073e9SAndroid Build Coastguard Worker            return torch.jit._wait(torch.jit._fork(foo, x1, x2))
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
237*da0073e9SAndroid Build Coastguard Worker        def foo_script_kwargs(x1, x2):
238*da0073e9SAndroid Build Coastguard Worker            return torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2))
239*da0073e9SAndroid Build Coastguard Worker
240*da0073e9SAndroid Build Coastguard Worker        for wrapper in [
241*da0073e9SAndroid Build Coastguard Worker            foo_script_args,
242*da0073e9SAndroid Build Coastguard Worker            foo_script_kwargs,
243*da0073e9SAndroid Build Coastguard Worker        ]:
244*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(wrapper(x1, x2), y_hat)
245*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(wrapper(x1, x2=x2), y_hat)
246*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(wrapper(x1=x1, x2=x2), y_hat)
247*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(wrapper(x2=x2, x1=x1), y_hat)
248*da0073e9SAndroid Build Coastguard Worker
249*da0073e9SAndroid Build Coastguard Worker    @_inline_everything
250*da0073e9SAndroid Build Coastguard Worker    def test_async_script_trace(self):
251*da0073e9SAndroid Build Coastguard Worker        class Traced(nn.Module):
252*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
253*da0073e9SAndroid Build Coastguard Worker                return (torch.neg(x), x)
254*da0073e9SAndroid Build Coastguard Worker
255*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.jit.ScriptModule):
256*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
257*da0073e9SAndroid Build Coastguard Worker                super().__init__()
258*da0073e9SAndroid Build Coastguard Worker                x = torch.rand(3, 3)
259*da0073e9SAndroid Build Coastguard Worker                self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
260*da0073e9SAndroid Build Coastguard Worker
261*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script_method
262*da0073e9SAndroid Build Coastguard Worker            def forward(
263*da0073e9SAndroid Build Coastguard Worker                self, x: Tensor
264*da0073e9SAndroid Build Coastguard Worker            ) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]:
265*da0073e9SAndroid Build Coastguard Worker                future1 = torch.jit._fork(self.traced, x)
266*da0073e9SAndroid Build Coastguard Worker                future2 = torch.jit._fork(torch.neg, x)
267*da0073e9SAndroid Build Coastguard Worker
268*da0073e9SAndroid Build Coastguard Worker                tensor_tuple = torch.jit._wait(future1)
269*da0073e9SAndroid Build Coastguard Worker                tensor_single = torch.jit._wait(future2)
270*da0073e9SAndroid Build Coastguard Worker
271*da0073e9SAndroid Build Coastguard Worker                tensor_list = []
272*da0073e9SAndroid Build Coastguard Worker                tensor_list.append(tensor_tuple[0])
273*da0073e9SAndroid Build Coastguard Worker                tensor_list.append(tensor_single)
274*da0073e9SAndroid Build Coastguard Worker
275*da0073e9SAndroid Build Coastguard Worker                # return a nested structure of tensors
276*da0073e9SAndroid Build Coastguard Worker                return (tensor_list, tensor_tuple, tensor_tuple[1])
277*da0073e9SAndroid Build Coastguard Worker
278*da0073e9SAndroid Build Coastguard Worker        class TupleCl(nn.Module):
279*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
280*da0073e9SAndroid Build Coastguard Worker                super().__init__()
281*da0073e9SAndroid Build Coastguard Worker                self.module = Mod()
282*da0073e9SAndroid Build Coastguard Worker
283*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
284*da0073e9SAndroid Build Coastguard Worker                z = torch.neg(x)
285*da0073e9SAndroid Build Coastguard Worker                y = self.module(x)
286*da0073e9SAndroid Build Coastguard Worker                list = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]]
287*da0073e9SAndroid Build Coastguard Worker                return tuple(list)
288*da0073e9SAndroid Build Coastguard Worker
289*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 3)
290*da0073e9SAndroid Build Coastguard Worker        module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Worker        # Make sure we have forks
293*da0073e9SAndroid Build Coastguard Worker        self.assertGraphContainsExactly(
294*da0073e9SAndroid Build Coastguard Worker            module.graph, kind="prim::fork", num_kind_nodes=2
295*da0073e9SAndroid Build Coastguard Worker        )
296*da0073e9SAndroid Build Coastguard Worker        # Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs
297*da0073e9SAndroid Build Coastguard Worker        self.assertGraphContainsExactly(
298*da0073e9SAndroid Build Coastguard Worker            module.graph, kind="aten::neg", num_kind_nodes=1
299*da0073e9SAndroid Build Coastguard Worker        )
300*da0073e9SAndroid Build Coastguard Worker        self.assertGraphContainsExactly(
301*da0073e9SAndroid Build Coastguard Worker            module.graph, kind="aten::neg", num_kind_nodes=3, consider_subgraphs=True
302*da0073e9SAndroid Build Coastguard Worker        )
303*da0073e9SAndroid Build Coastguard Worker
304*da0073e9SAndroid Build Coastguard Worker        y = torch.neg(x)
305*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(module(x), (y, y, y, y, x, x))
306*da0073e9SAndroid Build Coastguard Worker
307*da0073e9SAndroid Build Coastguard Worker    def test_async_script_error(self):
308*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
311*da0073e9SAndroid Build Coastguard Worker        def foo(x):
312*da0073e9SAndroid Build Coastguard Worker            # error here
313*da0073e9SAndroid Build Coastguard Worker            return x.t() + x
314*da0073e9SAndroid Build Coastguard Worker
315*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
316*da0073e9SAndroid Build Coastguard Worker        def wait_script(x):
317*da0073e9SAndroid Build Coastguard Worker            fut = torch.jit._fork(foo, x)
318*da0073e9SAndroid Build Coastguard Worker            return torch.jit._wait(fut)
319*da0073e9SAndroid Build Coastguard Worker
320*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
321*da0073e9SAndroid Build Coastguard Worker        def wait_script_nest(x):
322*da0073e9SAndroid Build Coastguard Worker            fut = torch.jit._fork(wait_script, x)
323*da0073e9SAndroid Build Coastguard Worker            return torch.jit._wait(fut)
324*da0073e9SAndroid Build Coastguard Worker
325*da0073e9SAndroid Build Coastguard Worker        # no future
326*da0073e9SAndroid Build Coastguard Worker        error_msg = "The size.*must match the size of tensor"
327*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(Exception, error_msg, "x.t() + x"):
328*da0073e9SAndroid Build Coastguard Worker            foo(x)
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Worker        # one future
331*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
332*da0073e9SAndroid Build Coastguard Worker            Exception, error_msg, "torch.jit._fork(foo, x"
333*da0073e9SAndroid Build Coastguard Worker        ):
334*da0073e9SAndroid Build Coastguard Worker            wait_script(x)
335*da0073e9SAndroid Build Coastguard Worker
336*da0073e9SAndroid Build Coastguard Worker        # two futures with a different error
337*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4, 5)
338*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
339*da0073e9SAndroid Build Coastguard Worker            Exception,
340*da0073e9SAndroid Build Coastguard Worker            "expects a tensor with <= 2 dimensions",
341*da0073e9SAndroid Build Coastguard Worker            "torch.jit._fork(wait_script, x",
342*da0073e9SAndroid Build Coastguard Worker        ):
343*da0073e9SAndroid Build Coastguard Worker            wait_script_nest(x)
344*da0073e9SAndroid Build Coastguard Worker
345*da0073e9SAndroid Build Coastguard Worker    def test_async_grad_guard_with_grad(self):
346*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
347*da0073e9SAndroid Build Coastguard Worker        def foo(x):
348*da0073e9SAndroid Build Coastguard Worker            y = x * 2
349*da0073e9SAndroid Build Coastguard Worker            return y.requires_grad
350*da0073e9SAndroid Build Coastguard Worker
351*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
352*da0073e9SAndroid Build Coastguard Worker        def bar(x):
353*da0073e9SAndroid Build Coastguard Worker            fut = torch.jit._fork(foo, x)
354*da0073e9SAndroid Build Coastguard Worker            requires_grad_in_fork = torch.jit._wait(fut)
355*da0073e9SAndroid Build Coastguard Worker            z = x * 2
356*da0073e9SAndroid Build Coastguard Worker            return (requires_grad_in_fork, z.requires_grad)
357*da0073e9SAndroid Build Coastguard Worker
358*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, requires_grad=True)
359*da0073e9SAndroid Build Coastguard Worker
360*da0073e9SAndroid Build Coastguard Worker        with torch.enable_grad():
361*da0073e9SAndroid Build Coastguard Worker            (inside_fork, after_wait) = bar(x)
362*da0073e9SAndroid Build Coastguard Worker
363*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inside_fork, True)
364*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(after_wait, True)
365*da0073e9SAndroid Build Coastguard Worker
366*da0073e9SAndroid Build Coastguard Worker    def test_async_grad_guard_no_grad(self):
367*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
368*da0073e9SAndroid Build Coastguard Worker        def foo(x):
369*da0073e9SAndroid Build Coastguard Worker            y = x * 2
370*da0073e9SAndroid Build Coastguard Worker            return y.requires_grad
371*da0073e9SAndroid Build Coastguard Worker
372*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
373*da0073e9SAndroid Build Coastguard Worker        def bar(x):
374*da0073e9SAndroid Build Coastguard Worker            fut = torch.jit._fork(foo, x)
375*da0073e9SAndroid Build Coastguard Worker            requires_grad_in_fork = torch.jit._wait(fut)
376*da0073e9SAndroid Build Coastguard Worker            z = x * 2
377*da0073e9SAndroid Build Coastguard Worker            return (requires_grad_in_fork, z.requires_grad)
378*da0073e9SAndroid Build Coastguard Worker
379*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, requires_grad=True)
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
382*da0073e9SAndroid Build Coastguard Worker            (inside_fork, after_wait) = bar(x)
383*da0073e9SAndroid Build Coastguard Worker
384*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(inside_fork, False)
385*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(after_wait, False)
386*da0073e9SAndroid Build Coastguard Worker
387*da0073e9SAndroid Build Coastguard Worker    def test_trace_fork_wait(self):
388*da0073e9SAndroid Build Coastguard Worker        def fork_body(x):
389*da0073e9SAndroid Build Coastguard Worker            return x.neg(), x.neg() + 1
390*da0073e9SAndroid Build Coastguard Worker
391*da0073e9SAndroid Build Coastguard Worker        def fn(x):
392*da0073e9SAndroid Build Coastguard Worker            fut = torch.jit._fork(fork_body, x)
393*da0073e9SAndroid Build Coastguard Worker            vals = torch.jit._wait(fut)
394*da0073e9SAndroid Build Coastguard Worker            return vals[0], vals[1], x - 1
395*da0073e9SAndroid Build Coastguard Worker
396*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(fn, (torch.rand(3, 4),))
397*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
398*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn(x), traced(x))
399*da0073e9SAndroid Build Coastguard Worker
400*da0073e9SAndroid Build Coastguard Worker        self.assertGraphContainsExactly(
401*da0073e9SAndroid Build Coastguard Worker            traced.graph, kind="prim::fork", num_kind_nodes=1
402*da0073e9SAndroid Build Coastguard Worker        )
403*da0073e9SAndroid Build Coastguard Worker        self.assertGraphContainsExactly(
404*da0073e9SAndroid Build Coastguard Worker            traced.graph, kind="aten::wait", num_kind_nodes=1
405*da0073e9SAndroid Build Coastguard Worker        )
406*da0073e9SAndroid Build Coastguard Worker        self.assertGraphContainsExactly(
407*da0073e9SAndroid Build Coastguard Worker            traced.graph, kind="aten::neg", num_kind_nodes=2, consider_subgraphs=True
408*da0073e9SAndroid Build Coastguard Worker        )
409*da0073e9SAndroid Build Coastguard Worker
410*da0073e9SAndroid Build Coastguard Worker    def test_trace_fork_wait_leaking(self):
411*da0073e9SAndroid Build Coastguard Worker        my_list = []
412*da0073e9SAndroid Build Coastguard Worker
413*da0073e9SAndroid Build Coastguard Worker        def fork_body(x):
414*da0073e9SAndroid Build Coastguard Worker            my_list.append(x + 1)
415*da0073e9SAndroid Build Coastguard Worker            return x + 1
416*da0073e9SAndroid Build Coastguard Worker
417*da0073e9SAndroid Build Coastguard Worker        def fn(x):
418*da0073e9SAndroid Build Coastguard Worker            fut = torch.jit._fork(fork_body, x)
419*da0073e9SAndroid Build Coastguard Worker            val = torch.jit._wait(fut)
420*da0073e9SAndroid Build Coastguard Worker            return my_list[0]
421*da0073e9SAndroid Build Coastguard Worker
422*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
423*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
424*da0073e9SAndroid Build Coastguard Worker            "did not have observable data dependence with trace inputs; "
425*da0073e9SAndroid Build Coastguard Worker            "this probably indicates your program cannot be understood "
426*da0073e9SAndroid Build Coastguard Worker            "by the tracer.",
427*da0073e9SAndroid Build Coastguard Worker            "",
428*da0073e9SAndroid Build Coastguard Worker        ):
429*da0073e9SAndroid Build Coastguard Worker            traced = torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False)
430*da0073e9SAndroid Build Coastguard Worker
431*da0073e9SAndroid Build Coastguard Worker    def test_trace_fork_wait_inline(self):
432*da0073e9SAndroid Build Coastguard Worker        def fork_body(x):
433*da0073e9SAndroid Build Coastguard Worker            return x + 1, x + 2
434*da0073e9SAndroid Build Coastguard Worker
435*da0073e9SAndroid Build Coastguard Worker        def fn(x):
436*da0073e9SAndroid Build Coastguard Worker            fut = torch.jit._fork(fork_body, x)
437*da0073e9SAndroid Build Coastguard Worker            val = torch.jit._wait(fut)
438*da0073e9SAndroid Build Coastguard Worker            return val[1]
439*da0073e9SAndroid Build Coastguard Worker
440*da0073e9SAndroid Build Coastguard Worker        traced = torch.jit.trace(fn, (torch.rand(3, 4),))
441*da0073e9SAndroid Build Coastguard Worker        torch._C._jit_pass_inline_fork_wait(traced.graph)
442*da0073e9SAndroid Build Coastguard Worker        self.assertGraphContainsExactly(
443*da0073e9SAndroid Build Coastguard Worker            traced.graph, kind="prim::fork", num_kind_nodes=0
444*da0073e9SAndroid Build Coastguard Worker        )
445*da0073e9SAndroid Build Coastguard Worker        self.assertGraphContainsExactly(
446*da0073e9SAndroid Build Coastguard Worker            traced.graph, kind="aten::wait", num_kind_nodes=0
447*da0073e9SAndroid Build Coastguard Worker        )
448*da0073e9SAndroid Build Coastguard Worker        self.assertGraphContainsExactly(
449*da0073e9SAndroid Build Coastguard Worker            traced.graph, kind="aten::add", num_kind_nodes=2
450*da0073e9SAndroid Build Coastguard Worker        )
451*da0073e9SAndroid Build Coastguard Worker
452*da0073e9SAndroid Build Coastguard Worker    def test_trace_fork_wait_list_modulecalls(self):
453*da0073e9SAndroid Build Coastguard Worker        def add_one(input):
454*da0073e9SAndroid Build Coastguard Worker            return input + torch.ones(input.size())
455*da0073e9SAndroid Build Coastguard Worker
456*da0073e9SAndroid Build Coastguard Worker        class TestListFutureModule(nn.Module):
457*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
458*da0073e9SAndroid Build Coastguard Worker                input_list = []
459*da0073e9SAndroid Build Coastguard Worker                for i in range(3):
460*da0073e9SAndroid Build Coastguard Worker                    input_list.append(input)
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Worker                fut_list: List[Future[torch.Tensor]] = []
463*da0073e9SAndroid Build Coastguard Worker                for input_tensor in input_list:
464*da0073e9SAndroid Build Coastguard Worker                    fut_list.append(torch.jit._fork(add_one, input_tensor))
465*da0073e9SAndroid Build Coastguard Worker                # return list[future[tensor]] here to ensure tracing
466*da0073e9SAndroid Build Coastguard Worker                # module calls return the correct types
467*da0073e9SAndroid Build Coastguard Worker                return fut_list
468*da0073e9SAndroid Build Coastguard Worker
469*da0073e9SAndroid Build Coastguard Worker        class TestModuleWrapper(nn.Module):
470*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
471*da0073e9SAndroid Build Coastguard Worker                super().__init__()
472*da0073e9SAndroid Build Coastguard Worker                self.list_fut_mod = TestListFutureModule()
473*da0073e9SAndroid Build Coastguard Worker
474*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
475*da0073e9SAndroid Build Coastguard Worker                fut_list = self.list_fut_mod(input)
476*da0073e9SAndroid Build Coastguard Worker                res = input
477*da0073e9SAndroid Build Coastguard Worker                for fut in fut_list:
478*da0073e9SAndroid Build Coastguard Worker                    res = res + fut.wait()
479*da0073e9SAndroid Build Coastguard Worker                return res
480*da0073e9SAndroid Build Coastguard Worker
481*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(TestModuleWrapper(), (torch.randn(5, 5),))
482*da0073e9SAndroid Build Coastguard Worker
483*da0073e9SAndroid Build Coastguard Worker    def test_trace_modulecalls_with_different_output_types(self):
484*da0073e9SAndroid Build Coastguard Worker        def add_one(input):
485*da0073e9SAndroid Build Coastguard Worker            return input + torch.ones(input.size())
486*da0073e9SAndroid Build Coastguard Worker
487*da0073e9SAndroid Build Coastguard Worker        class DifferentOutputModule(nn.Module):
488*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
489*da0073e9SAndroid Build Coastguard Worker                fut_res = torch.jit._fork(add_one, (input))
490*da0073e9SAndroid Build Coastguard Worker
491*da0073e9SAndroid Build Coastguard Worker                # return different types from module call
492*da0073e9SAndroid Build Coastguard Worker                return input, fut_res
493*da0073e9SAndroid Build Coastguard Worker
494*da0073e9SAndroid Build Coastguard Worker        class TestModule(nn.Module):
495*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
496*da0073e9SAndroid Build Coastguard Worker                super().__init__()
497*da0073e9SAndroid Build Coastguard Worker                self.gen_output = DifferentOutputModule()
498*da0073e9SAndroid Build Coastguard Worker
499*da0073e9SAndroid Build Coastguard Worker            def forward(self, input):
500*da0073e9SAndroid Build Coastguard Worker                res, fut_res = self.gen_output(input)
501*da0073e9SAndroid Build Coastguard Worker                res = res + fut_res.wait()
502*da0073e9SAndroid Build Coastguard Worker                return res
503*da0073e9SAndroid Build Coastguard Worker
504*da0073e9SAndroid Build Coastguard Worker        self.checkTrace(TestModule(), (torch.randn(5, 5),))
505*da0073e9SAndroid Build Coastguard Worker
506*da0073e9SAndroid Build Coastguard Worker    def test_no_future_subtype_message(self):
507*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
508*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Future without a contained type", ""
509*da0073e9SAndroid Build Coastguard Worker        ):
510*da0073e9SAndroid Build Coastguard Worker
511*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
512*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
513*da0073e9SAndroid Build Coastguard Worker                futs = torch.jit.annotate(List[torch.jit.Future], [])
514*da0073e9SAndroid Build Coastguard Worker
515*da0073e9SAndroid Build Coastguard Worker    def test_future_subtyping(self):
516*da0073e9SAndroid Build Coastguard Worker        """
517*da0073e9SAndroid Build Coastguard Worker        Test that futures subtype each other properly.
518*da0073e9SAndroid Build Coastguard Worker        """
519*da0073e9SAndroid Build Coastguard Worker
520*da0073e9SAndroid Build Coastguard Worker        # Successful subtyping.
521*da0073e9SAndroid Build Coastguard Worker        def returns_int(x: int) -> int:
522*da0073e9SAndroid Build Coastguard Worker            return x + x + 1
523*da0073e9SAndroid Build Coastguard Worker
524*da0073e9SAndroid Build Coastguard Worker        def returns_future_any(x: int) -> torch.jit.Future[Any]:
525*da0073e9SAndroid Build Coastguard Worker            return torch.jit._fork(returns_int, (x))
526*da0073e9SAndroid Build Coastguard Worker
527*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
528*da0073e9SAndroid Build Coastguard Worker        def fn_int(x: int) -> Any:
529*da0073e9SAndroid Build Coastguard Worker            fut = returns_future_any(x)
530*da0073e9SAndroid Build Coastguard Worker            return fut.wait()
531*da0073e9SAndroid Build Coastguard Worker
532*da0073e9SAndroid Build Coastguard Worker        # Unsuccessful subtyping.
533*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegexWithHighlight(
534*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
535*da0073e9SAndroid Build Coastguard Worker            r"was annotated as having type Future\[float\] but is actually of type Future\[int\]",
536*da0073e9SAndroid Build Coastguard Worker            "fut = returns_future_float(x",
537*da0073e9SAndroid Build Coastguard Worker        ):
538*da0073e9SAndroid Build Coastguard Worker
539*da0073e9SAndroid Build Coastguard Worker            def returns_future_float(x: int) -> torch.jit.Future[float]:
540*da0073e9SAndroid Build Coastguard Worker                return torch.jit._fork(returns_int, (x))
541*da0073e9SAndroid Build Coastguard Worker
542*da0073e9SAndroid Build Coastguard Worker            @torch.jit.script
543*da0073e9SAndroid Build Coastguard Worker            def fn_float(x: int) -> Any:
544*da0073e9SAndroid Build Coastguard Worker                fut = returns_future_float(x)
545*da0073e9SAndroid Build Coastguard Worker                return fut.wait()
546*da0073e9SAndroid Build Coastguard Worker
547*da0073e9SAndroid Build Coastguard Worker
548*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
549*da0073e9SAndroid Build Coastguard Worker    raise RuntimeError(
550*da0073e9SAndroid Build Coastguard Worker        "This test file is not meant to be run directly, use:\n\n"
551*da0073e9SAndroid Build Coastguard Worker        "\tpython test/test_jit.py TESTNAME\n\n"
552*da0073e9SAndroid Build Coastguard Worker        "instead."
553*da0073e9SAndroid Build Coastguard Worker    )
554