# xref: /aosp_15_r20/external/pytorch/test/jit/test_async.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5from typing import Any, Tuple
6
7import torch
8import torch.nn as nn
9
10
# Make the helper files in test/ importable: resolve this file's real path,
# walk two directory levels up, and append that directory to sys.path.
pytorch_test_dir = os.path.realpath(__file__)
pytorch_test_dir = os.path.dirname(os.path.dirname(pytorch_test_dir))
sys.path.append(pytorch_test_dir)
14from typing import List
15
16from torch import Tensor
17from torch.jit import Future
18from torch.testing._internal.jit_utils import _inline_everything, JitTestCase
19
20
21class TestAsync(JitTestCase):
22    def test_async_python(self):
23        @torch.jit.script
24        def foo(x):
25            return torch.neg(x)
26
27        x = torch.rand(3, 4)
28        fut = torch.jit.fork(foo, x)
29        y_hat = foo(x)
30        y = torch.jit.wait(fut)
31        # assert nothing; only to make sure the fake python path works
32
33    def test_async_future_type_python(self):
34        def foo(inp):
35            futures = torch.jit.annotate(List[torch.jit.Future[torch.Tensor]], [])
36            for i in range(5):
37                futures.append(torch.jit.fork(lambda x: x, inp))
38            all_outputs = []
39            for future in futures:
40                all_outputs.append(torch.jit.wait(future))
41            return all_outputs
42
43        # assert nothing, just to make sure python type parsing works
44        foo(torch.randn(3, 4))
45
46    def test_async_parsing(self):
47        @torch.jit.script
48        def foo(x: Tensor) -> List[Tensor]:
49            return [torch.neg(x), x.t()]
50
51        @torch.jit.script
52        def bar(x):
53            futures = torch.jit.annotate(List[Future[List[Tensor]]], [])
54            for _ in range(3):
55                future = torch.jit.annotate(
56                    Future[List[Tensor]], torch.jit.fork(foo, x)
57                )
58                futures.append(future)
59
60            output = torch.jit.annotate(List[List[Tensor]], [])
61            for i in range(3):
62                output.append(torch.jit.wait(futures[i]))
63            return output
64
65        x = torch.rand(3, 3)
66        result = bar(x)
67        self.assertEqual(len(result), 3)
68
69    def test_async_script(self):
70        @torch.jit.script
71        def foo(x):
72            return torch.neg(x), x
73
74        x = torch.rand(3, 4)
75
76        @torch.jit.script
77        def wait_script(x):
78            fut = torch.jit.fork(foo, x)
79            y_hat = foo(x)
80            y = torch.jit.wait(fut)
81            return y, y_hat
82
83        y, y_hat = wait_script(x)
84
85        self.assertEqual(y, y_hat)
86
87    def test_async_script_capture(self):
88        class Mod(torch.jit.ScriptModule):
89            __constants__ = ["const"]
90
91            def __init__(self) -> None:
92                super().__init__()
93                self.const = 42
94                self.param = nn.Parameter(torch.randn(2, 2))
95
96            @torch.jit.script_method
97            def foo(self, x1, x2):
98                return torch.neg(x1), self.param, self.const, torch.neg(x2), self.param
99
100            @torch.jit.script_method
101            def forward(self, x1, x2):
102                fut = torch.jit.fork(self.foo, x1, x2)
103                y_hat = self.foo(x1, x2)
104                y = torch.jit.wait(fut)
105                return y, y_hat
106
107        x1 = torch.rand(3, 4)
108        x2 = torch.rand(5, 6)
109
110        m = Mod()
111
112        with torch.jit.optimized_execution(False):
113            y, y_hat = m.forward(x1, x2)
114
115        self.assertEqual(y, y_hat)
116
117    def test_async_script_nested(self):
118        @torch.jit.script
119        def foo(x):
120            return torch.neg(x), x
121
122        x = torch.rand(3, 4)
123
124        @torch.jit.script
125        def wait_script(x):
126            fut = torch.jit._fork(foo, x)
127            y_hat = foo(x)
128            y = torch.jit._wait(fut)
129            return y, y_hat
130
131        @torch.jit.script
132        def wait_script_nest(x):
133            fut = torch.jit._fork(wait_script, x)
134            return torch.jit._wait(fut)
135
136        y, y_hat = wait_script_nest(x)
137
138        self.assertEqual(y, y_hat)
139
140    def test_async_script_no_script_mod(self):
141        x = torch.rand(3, 4)
142
143        with self.assertRaisesRegexWithHighlight(
144            RuntimeError, "cannot call a value", "torch.jit._fork(x"
145        ):
146
147            @torch.jit.script
148            def wait_script(x):
149                fut = torch.jit._fork(x)
150                return fut
151
152    def test_async_script_multi_waits(self):
153        @torch.jit.script
154        def foo(x):
155            return torch.neg(x).t() + x
156
157        @torch.jit.script
158        def wait_script(x):
159            fut = torch.jit._fork(foo, x)
160
161            # wait twice on the same future
162            y1 = torch.jit._wait(fut)
163            y2 = torch.jit._wait(fut)
164            return y1, y2
165
166        x = torch.rand(2, 2)
167        y1, y2 = wait_script(x)
168        self.assertEqual(y1, y2)
169
170    def test_async_script_multi_forks(self):
171        @torch.jit.script
172        def foo1(x):
173            return torch.neg(x).t() + x
174
175        @torch.jit.script
176        def foo2(x, y):
177            return torch.neg(x).t() + x + torch.neg(y).t()
178
179        @torch.jit.script
180        def foo3(x, y, z):
181            return torch.neg(z).t() + y.t() + x
182
183        x1 = torch.rand(10, 10)
184        x2 = torch.rand(10, 10)
185        x3 = torch.rand(10, 10)
186
187        @torch.jit.script
188        def wait_script(x1, x2, x3):
189            f1 = torch.jit._fork(foo1, x1)
190            f2 = torch.jit._fork(foo2, x1, x2)
191            f3 = torch.jit._fork(foo3, x1, x2, x3)
192            f4 = torch.jit._fork(foo1, x2)
193            f5 = torch.jit._fork(foo2, x2, x3)
194
195            # ignore some forks
196            y1 = torch.jit._wait(f1)
197            y2 = torch.jit._wait(f2)
198            y3 = torch.jit._wait(f3)
199
200            return y1, y2, y3
201
202        y1, y2, y3 = wait_script(x1, x2, x3)
203        self.assertEqual(y1, foo1(x1))
204        self.assertEqual(y2, foo2(x1, x2))
205        self.assertEqual(y3, foo3(x1, x2, x3))
206
207    def test_async_kwargs(self):
208        def foo(x1, x2):
209            return 2 * x1 + x2
210
211        x1 = torch.rand(3, 4)
212        x2 = torch.rand(3, 4)
213        y_hat = foo(x1, x2)
214
215        # Cover tracing and bare functions with permutations of args, kwargs
216        for func in [
217            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2)),
218            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2=x2)),
219            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)),
220            lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1)),
221        ]:
222            for wrapper in [
223                func,
224                torch.jit.trace(func, (x1, x2)),
225            ]:
226                self.assertEqual(wrapper(x1, x2), y_hat)
227                self.assertEqual(wrapper(x1, x2=x2), y_hat)
228                self.assertEqual(wrapper(x1=x1, x2=x2), y_hat)
229                self.assertEqual(wrapper(x2=x2, x1=x1), y_hat)
230
231        # Cover scripting
232        @torch.jit.script
233        def foo_script_args(x1, x2):
234            return torch.jit._wait(torch.jit._fork(foo, x1, x2))
235
236        @torch.jit.script
237        def foo_script_kwargs(x1, x2):
238            return torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2))
239
240        for wrapper in [
241            foo_script_args,
242            foo_script_kwargs,
243        ]:
244            self.assertEqual(wrapper(x1, x2), y_hat)
245            self.assertEqual(wrapper(x1, x2=x2), y_hat)
246            self.assertEqual(wrapper(x1=x1, x2=x2), y_hat)
247            self.assertEqual(wrapper(x2=x2, x1=x1), y_hat)
248
249    @_inline_everything
250    def test_async_script_trace(self):
251        class Traced(nn.Module):
252            def forward(self, x):
253                return (torch.neg(x), x)
254
255        class Mod(torch.jit.ScriptModule):
256            def __init__(self) -> None:
257                super().__init__()
258                x = torch.rand(3, 3)
259                self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
260
261            @torch.jit.script_method
262            def forward(
263                self, x: Tensor
264            ) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]:
265                future1 = torch.jit._fork(self.traced, x)
266                future2 = torch.jit._fork(torch.neg, x)
267
268                tensor_tuple = torch.jit._wait(future1)
269                tensor_single = torch.jit._wait(future2)
270
271                tensor_list = []
272                tensor_list.append(tensor_tuple[0])
273                tensor_list.append(tensor_single)
274
275                # return a nested structure of tensors
276                return (tensor_list, tensor_tuple, tensor_tuple[1])
277
278        class TupleCl(nn.Module):
279            def __init__(self) -> None:
280                super().__init__()
281                self.module = Mod()
282
283            def forward(self, x):
284                z = torch.neg(x)
285                y = self.module(x)
286                list = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]]
287                return tuple(list)
288
289        x = torch.rand(3, 3)
290        module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)
291
292        # Make sure we have forks
293        self.assertGraphContainsExactly(
294            module.graph, kind="prim::fork", num_kind_nodes=2
295        )
296        # Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs
297        self.assertGraphContainsExactly(
298            module.graph, kind="aten::neg", num_kind_nodes=1
299        )
300        self.assertGraphContainsExactly(
301            module.graph, kind="aten::neg", num_kind_nodes=3, consider_subgraphs=True
302        )
303
304        y = torch.neg(x)
305        self.assertEqual(module(x), (y, y, y, y, x, x))
306
307    def test_async_script_error(self):
308        x = torch.rand(3, 4)
309
310        @torch.jit.script
311        def foo(x):
312            # error here
313            return x.t() + x
314
315        @torch.jit.script
316        def wait_script(x):
317            fut = torch.jit._fork(foo, x)
318            return torch.jit._wait(fut)
319
320        @torch.jit.script
321        def wait_script_nest(x):
322            fut = torch.jit._fork(wait_script, x)
323            return torch.jit._wait(fut)
324
325        # no future
326        error_msg = "The size.*must match the size of tensor"
327        with self.assertRaisesRegexWithHighlight(Exception, error_msg, "x.t() + x"):
328            foo(x)
329
330        # one future
331        with self.assertRaisesRegexWithHighlight(
332            Exception, error_msg, "torch.jit._fork(foo, x"
333        ):
334            wait_script(x)
335
336        # two futures with a different error
337        x = torch.rand(3, 4, 5)
338        with self.assertRaisesRegexWithHighlight(
339            Exception,
340            "expects a tensor with <= 2 dimensions",
341            "torch.jit._fork(wait_script, x",
342        ):
343            wait_script_nest(x)
344
345    def test_async_grad_guard_with_grad(self):
346        @torch.jit.script
347        def foo(x):
348            y = x * 2
349            return y.requires_grad
350
351        @torch.jit.script
352        def bar(x):
353            fut = torch.jit._fork(foo, x)
354            requires_grad_in_fork = torch.jit._wait(fut)
355            z = x * 2
356            return (requires_grad_in_fork, z.requires_grad)
357
358        x = torch.randn(3, requires_grad=True)
359
360        with torch.enable_grad():
361            (inside_fork, after_wait) = bar(x)
362
363        self.assertEqual(inside_fork, True)
364        self.assertEqual(after_wait, True)
365
366    def test_async_grad_guard_no_grad(self):
367        @torch.jit.script
368        def foo(x):
369            y = x * 2
370            return y.requires_grad
371
372        @torch.jit.script
373        def bar(x):
374            fut = torch.jit._fork(foo, x)
375            requires_grad_in_fork = torch.jit._wait(fut)
376            z = x * 2
377            return (requires_grad_in_fork, z.requires_grad)
378
379        x = torch.randn(3, requires_grad=True)
380
381        with torch.no_grad():
382            (inside_fork, after_wait) = bar(x)
383
384        self.assertEqual(inside_fork, False)
385        self.assertEqual(after_wait, False)
386
387    def test_trace_fork_wait(self):
388        def fork_body(x):
389            return x.neg(), x.neg() + 1
390
391        def fn(x):
392            fut = torch.jit._fork(fork_body, x)
393            vals = torch.jit._wait(fut)
394            return vals[0], vals[1], x - 1
395
396        traced = torch.jit.trace(fn, (torch.rand(3, 4),))
397        x = torch.rand(3, 4)
398        self.assertEqual(fn(x), traced(x))
399
400        self.assertGraphContainsExactly(
401            traced.graph, kind="prim::fork", num_kind_nodes=1
402        )
403        self.assertGraphContainsExactly(
404            traced.graph, kind="aten::wait", num_kind_nodes=1
405        )
406        self.assertGraphContainsExactly(
407            traced.graph, kind="aten::neg", num_kind_nodes=2, consider_subgraphs=True
408        )
409
410    def test_trace_fork_wait_leaking(self):
411        my_list = []
412
413        def fork_body(x):
414            my_list.append(x + 1)
415            return x + 1
416
417        def fn(x):
418            fut = torch.jit._fork(fork_body, x)
419            val = torch.jit._wait(fut)
420            return my_list[0]
421
422        with self.assertRaisesRegexWithHighlight(
423            RuntimeError,
424            "did not have observable data dependence with trace inputs; "
425            "this probably indicates your program cannot be understood "
426            "by the tracer.",
427            "",
428        ):
429            traced = torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False)
430
431    def test_trace_fork_wait_inline(self):
432        def fork_body(x):
433            return x + 1, x + 2
434
435        def fn(x):
436            fut = torch.jit._fork(fork_body, x)
437            val = torch.jit._wait(fut)
438            return val[1]
439
440        traced = torch.jit.trace(fn, (torch.rand(3, 4),))
441        torch._C._jit_pass_inline_fork_wait(traced.graph)
442        self.assertGraphContainsExactly(
443            traced.graph, kind="prim::fork", num_kind_nodes=0
444        )
445        self.assertGraphContainsExactly(
446            traced.graph, kind="aten::wait", num_kind_nodes=0
447        )
448        self.assertGraphContainsExactly(
449            traced.graph, kind="aten::add", num_kind_nodes=2
450        )
451
452    def test_trace_fork_wait_list_modulecalls(self):
453        def add_one(input):
454            return input + torch.ones(input.size())
455
456        class TestListFutureModule(nn.Module):
457            def forward(self, input):
458                input_list = []
459                for i in range(3):
460                    input_list.append(input)
461
462                fut_list: List[Future[torch.Tensor]] = []
463                for input_tensor in input_list:
464                    fut_list.append(torch.jit._fork(add_one, input_tensor))
465                # return list[future[tensor]] here to ensure tracing
466                # module calls return the correct types
467                return fut_list
468
469        class TestModuleWrapper(nn.Module):
470            def __init__(self) -> None:
471                super().__init__()
472                self.list_fut_mod = TestListFutureModule()
473
474            def forward(self, input):
475                fut_list = self.list_fut_mod(input)
476                res = input
477                for fut in fut_list:
478                    res = res + fut.wait()
479                return res
480
481        self.checkTrace(TestModuleWrapper(), (torch.randn(5, 5),))
482
483    def test_trace_modulecalls_with_different_output_types(self):
484        def add_one(input):
485            return input + torch.ones(input.size())
486
487        class DifferentOutputModule(nn.Module):
488            def forward(self, input):
489                fut_res = torch.jit._fork(add_one, (input))
490
491                # return different types from module call
492                return input, fut_res
493
494        class TestModule(nn.Module):
495            def __init__(self) -> None:
496                super().__init__()
497                self.gen_output = DifferentOutputModule()
498
499            def forward(self, input):
500                res, fut_res = self.gen_output(input)
501                res = res + fut_res.wait()
502                return res
503
504        self.checkTrace(TestModule(), (torch.randn(5, 5),))
505
506    def test_no_future_subtype_message(self):
507        with self.assertRaisesRegexWithHighlight(
508            RuntimeError, "Future without a contained type", ""
509        ):
510
511            @torch.jit.script
512            def forward(self, x):
513                futs = torch.jit.annotate(List[torch.jit.Future], [])
514
515    def test_future_subtyping(self):
516        """
517        Test that futures subtype each other properly.
518        """
519
520        # Successful subtyping.
521        def returns_int(x: int) -> int:
522            return x + x + 1
523
524        def returns_future_any(x: int) -> torch.jit.Future[Any]:
525            return torch.jit._fork(returns_int, (x))
526
527        @torch.jit.script
528        def fn_int(x: int) -> Any:
529            fut = returns_future_any(x)
530            return fut.wait()
531
532        # Unsuccessful subtyping.
533        with self.assertRaisesRegexWithHighlight(
534            RuntimeError,
535            r"was annotated as having type Future\[float\] but is actually of type Future\[int\]",
536            "fut = returns_future_float(x",
537        ):
538
539            def returns_future_float(x: int) -> torch.jit.Future[float]:
540                return torch.jit._fork(returns_int, (x))
541
542            @torch.jit.script
543            def fn_float(x: int) -> Any:
544                fut = returns_future_float(x)
545                return fut.wait()
546
547
548if __name__ == "__main__":
549    raise RuntimeError(
550        "This test file is not meant to be run directly, use:\n\n"
551        "\tpython test/test_jit.py TESTNAME\n\n"
552        "instead."
553    )
554