# Owner(s): ["oncall: jit"] import os import sys from typing import Any, Tuple import torch import torch.nn as nn # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) from typing import List from torch import Tensor from torch.jit import Future from torch.testing._internal.jit_utils import _inline_everything, JitTestCase class TestAsync(JitTestCase): def test_async_python(self): @torch.jit.script def foo(x): return torch.neg(x) x = torch.rand(3, 4) fut = torch.jit.fork(foo, x) y_hat = foo(x) y = torch.jit.wait(fut) # assert nothing; only to make sure the fake python path works def test_async_future_type_python(self): def foo(inp): futures = torch.jit.annotate(List[torch.jit.Future[torch.Tensor]], []) for i in range(5): futures.append(torch.jit.fork(lambda x: x, inp)) all_outputs = [] for future in futures: all_outputs.append(torch.jit.wait(future)) return all_outputs # assert nothing, just to make sure python type parsing works foo(torch.randn(3, 4)) def test_async_parsing(self): @torch.jit.script def foo(x: Tensor) -> List[Tensor]: return [torch.neg(x), x.t()] @torch.jit.script def bar(x): futures = torch.jit.annotate(List[Future[List[Tensor]]], []) for _ in range(3): future = torch.jit.annotate( Future[List[Tensor]], torch.jit.fork(foo, x) ) futures.append(future) output = torch.jit.annotate(List[List[Tensor]], []) for i in range(3): output.append(torch.jit.wait(futures[i])) return output x = torch.rand(3, 3) result = bar(x) self.assertEqual(len(result), 3) def test_async_script(self): @torch.jit.script def foo(x): return torch.neg(x), x x = torch.rand(3, 4) @torch.jit.script def wait_script(x): fut = torch.jit.fork(foo, x) y_hat = foo(x) y = torch.jit.wait(fut) return y, y_hat y, y_hat = wait_script(x) self.assertEqual(y, y_hat) def test_async_script_capture(self): class Mod(torch.jit.ScriptModule): __constants__ = ["const"] def __init__(self) -> None: super().__init__() self.const = 42 self.param = nn.Parameter(torch.randn(2, 2)) @torch.jit.script_method def foo(self, x1, x2): return torch.neg(x1), self.param, self.const, torch.neg(x2), self.param @torch.jit.script_method def forward(self, x1, x2): fut = torch.jit.fork(self.foo, x1, x2) y_hat = self.foo(x1, x2) y = torch.jit.wait(fut) return y, y_hat x1 = torch.rand(3, 4) x2 = torch.rand(5, 6) m = Mod() with torch.jit.optimized_execution(False): y, y_hat = m.forward(x1, x2) self.assertEqual(y, y_hat) def test_async_script_nested(self): @torch.jit.script def foo(x): return torch.neg(x), x x = torch.rand(3, 4) @torch.jit.script def wait_script(x): fut = torch.jit._fork(foo, x) y_hat = foo(x) y = torch.jit._wait(fut) return y, y_hat @torch.jit.script def wait_script_nest(x): fut = torch.jit._fork(wait_script, x) return torch.jit._wait(fut) y, y_hat = wait_script_nest(x) self.assertEqual(y, y_hat) def test_async_script_no_script_mod(self): x = torch.rand(3, 4) with self.assertRaisesRegexWithHighlight( RuntimeError, "cannot call a value", "torch.jit._fork(x" ): @torch.jit.script def wait_script(x): fut = torch.jit._fork(x) return fut def test_async_script_multi_waits(self): @torch.jit.script def foo(x): return torch.neg(x).t() + x @torch.jit.script def wait_script(x): fut = torch.jit._fork(foo, x) # wait twice on the same future y1 = torch.jit._wait(fut) y2 = torch.jit._wait(fut) return y1, y2 x = torch.rand(2, 2) y1, y2 = wait_script(x) self.assertEqual(y1, y2) def test_async_script_multi_forks(self): @torch.jit.script def foo1(x): return torch.neg(x).t() + x @torch.jit.script def foo2(x, y): return torch.neg(x).t() + x + torch.neg(y).t() @torch.jit.script def foo3(x, y, z): return torch.neg(z).t() + y.t() + x x1 = torch.rand(10, 10) x2 = torch.rand(10, 10) x3 = torch.rand(10, 10) @torch.jit.script def wait_script(x1, x2, x3): f1 = torch.jit._fork(foo1, x1) f2 = torch.jit._fork(foo2, x1, x2) f3 = torch.jit._fork(foo3, x1, x2, x3) f4 = torch.jit._fork(foo1, x2) f5 = torch.jit._fork(foo2, x2, x3) # ignore some forks y1 = torch.jit._wait(f1) y2 = torch.jit._wait(f2) y3 = torch.jit._wait(f3) return y1, y2, y3 y1, y2, y3 = wait_script(x1, x2, x3) self.assertEqual(y1, foo1(x1)) self.assertEqual(y2, foo2(x1, x2)) self.assertEqual(y3, foo3(x1, x2, x3)) def test_async_kwargs(self): def foo(x1, x2): return 2 * x1 + x2 x1 = torch.rand(3, 4) x2 = torch.rand(3, 4) y_hat = foo(x1, x2) # Cover tracing and bare functions with permutations of args, kwargs for func in [ lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2)), lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2=x2)), lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)), lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1)), ]: for wrapper in [ func, torch.jit.trace(func, (x1, x2)), ]: self.assertEqual(wrapper(x1, x2), y_hat) self.assertEqual(wrapper(x1, x2=x2), y_hat) self.assertEqual(wrapper(x1=x1, x2=x2), y_hat) self.assertEqual(wrapper(x2=x2, x1=x1), y_hat) # Cover scripting @torch.jit.script def foo_script_args(x1, x2): return torch.jit._wait(torch.jit._fork(foo, x1, x2)) @torch.jit.script def foo_script_kwargs(x1, x2): return torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)) for wrapper in [ foo_script_args, foo_script_kwargs, ]: self.assertEqual(wrapper(x1, x2), y_hat) self.assertEqual(wrapper(x1, x2=x2), y_hat) self.assertEqual(wrapper(x1=x1, x2=x2), y_hat) self.assertEqual(wrapper(x2=x2, x1=x1), y_hat) @_inline_everything def test_async_script_trace(self): class Traced(nn.Module): def forward(self, x): return (torch.neg(x), x) class Mod(torch.jit.ScriptModule): def __init__(self) -> None: super().__init__() x = torch.rand(3, 3) self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True) @torch.jit.script_method def forward( self, x: Tensor ) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]: future1 = torch.jit._fork(self.traced, x) future2 = torch.jit._fork(torch.neg, x) tensor_tuple = torch.jit._wait(future1) tensor_single = torch.jit._wait(future2) tensor_list = [] tensor_list.append(tensor_tuple[0]) tensor_list.append(tensor_single) # return a nested structure of tensors return (tensor_list, tensor_tuple, tensor_tuple[1]) class TupleCl(nn.Module): def __init__(self) -> None: super().__init__() self.module = Mod() def forward(self, x): z = torch.neg(x) y = self.module(x) list = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]] return tuple(list) x = torch.rand(3, 3) module = torch.jit.trace(TupleCl(), (x), _force_outplace=True) # Make sure we have forks self.assertGraphContainsExactly( module.graph, kind="prim::fork", num_kind_nodes=2 ) # Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs self.assertGraphContainsExactly( module.graph, kind="aten::neg", num_kind_nodes=1 ) self.assertGraphContainsExactly( module.graph, kind="aten::neg", num_kind_nodes=3, consider_subgraphs=True ) y = torch.neg(x) self.assertEqual(module(x), (y, y, y, y, x, x)) def test_async_script_error(self): x = torch.rand(3, 4) @torch.jit.script def foo(x): # error here return x.t() + x @torch.jit.script def wait_script(x): fut = torch.jit._fork(foo, x) return torch.jit._wait(fut) @torch.jit.script def wait_script_nest(x): fut = torch.jit._fork(wait_script, x) return torch.jit._wait(fut) # no future error_msg = "The size.*must match the size of tensor" with self.assertRaisesRegexWithHighlight(Exception, error_msg, "x.t() + x"): foo(x) # one future with self.assertRaisesRegexWithHighlight( Exception, error_msg, "torch.jit._fork(foo, x" ): wait_script(x) # two futures with a different error x = torch.rand(3, 4, 5) with self.assertRaisesRegexWithHighlight( Exception, "expects a tensor with <= 2 dimensions", "torch.jit._fork(wait_script, x", ): wait_script_nest(x) def test_async_grad_guard_with_grad(self): @torch.jit.script def foo(x): y = x * 2 return y.requires_grad @torch.jit.script def bar(x): fut = torch.jit._fork(foo, x) requires_grad_in_fork = torch.jit._wait(fut) z = x * 2 return (requires_grad_in_fork, z.requires_grad) x = torch.randn(3, requires_grad=True) with torch.enable_grad(): (inside_fork, after_wait) = bar(x) self.assertEqual(inside_fork, True) self.assertEqual(after_wait, True) def test_async_grad_guard_no_grad(self): @torch.jit.script def foo(x): y = x * 2 return y.requires_grad @torch.jit.script def bar(x): fut = torch.jit._fork(foo, x) requires_grad_in_fork = torch.jit._wait(fut) z = x * 2 return (requires_grad_in_fork, z.requires_grad) x = torch.randn(3, requires_grad=True) with torch.no_grad(): (inside_fork, after_wait) = bar(x) self.assertEqual(inside_fork, False) self.assertEqual(after_wait, False) def test_trace_fork_wait(self): def fork_body(x): return x.neg(), x.neg() + 1 def fn(x): fut = torch.jit._fork(fork_body, x) vals = torch.jit._wait(fut) return vals[0], vals[1], x - 1 traced = torch.jit.trace(fn, (torch.rand(3, 4),)) x = torch.rand(3, 4) self.assertEqual(fn(x), traced(x)) self.assertGraphContainsExactly( traced.graph, kind="prim::fork", num_kind_nodes=1 ) self.assertGraphContainsExactly( traced.graph, kind="aten::wait", num_kind_nodes=1 ) self.assertGraphContainsExactly( traced.graph, kind="aten::neg", num_kind_nodes=2, consider_subgraphs=True ) def test_trace_fork_wait_leaking(self): my_list = [] def fork_body(x): my_list.append(x + 1) return x + 1 def fn(x): fut = torch.jit._fork(fork_body, x) val = torch.jit._wait(fut) return my_list[0] with self.assertRaisesRegexWithHighlight( RuntimeError, "did not have observable data dependence with trace inputs; " "this probably indicates your program cannot be understood " "by the tracer.", "", ): traced = torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False) def test_trace_fork_wait_inline(self): def fork_body(x): return x + 1, x + 2 def fn(x): fut = torch.jit._fork(fork_body, x) val = torch.jit._wait(fut) return val[1] traced = torch.jit.trace(fn, (torch.rand(3, 4),)) torch._C._jit_pass_inline_fork_wait(traced.graph) self.assertGraphContainsExactly( traced.graph, kind="prim::fork", num_kind_nodes=0 ) self.assertGraphContainsExactly( traced.graph, kind="aten::wait", num_kind_nodes=0 ) self.assertGraphContainsExactly( traced.graph, kind="aten::add", num_kind_nodes=2 ) def test_trace_fork_wait_list_modulecalls(self): def add_one(input): return input + torch.ones(input.size()) class TestListFutureModule(nn.Module): def forward(self, input): input_list = [] for i in range(3): input_list.append(input) fut_list: List[Future[torch.Tensor]] = [] for input_tensor in input_list: fut_list.append(torch.jit._fork(add_one, input_tensor)) # return list[future[tensor]] here to ensure tracing # module calls return the correct types return fut_list class TestModuleWrapper(nn.Module): def __init__(self) -> None: super().__init__() self.list_fut_mod = TestListFutureModule() def forward(self, input): fut_list = self.list_fut_mod(input) res = input for fut in fut_list: res = res + fut.wait() return res self.checkTrace(TestModuleWrapper(), (torch.randn(5, 5),)) def test_trace_modulecalls_with_different_output_types(self): def add_one(input): return input + torch.ones(input.size()) class DifferentOutputModule(nn.Module): def forward(self, input): fut_res = torch.jit._fork(add_one, (input)) # return different types from module call return input, fut_res class TestModule(nn.Module): def __init__(self) -> None: super().__init__() self.gen_output = DifferentOutputModule() def forward(self, input): res, fut_res = self.gen_output(input) res = res + fut_res.wait() return res self.checkTrace(TestModule(), (torch.randn(5, 5),)) def test_no_future_subtype_message(self): with self.assertRaisesRegexWithHighlight( RuntimeError, "Future without a contained type", "" ): @torch.jit.script def forward(self, x): futs = torch.jit.annotate(List[torch.jit.Future], []) def test_future_subtyping(self): """ Test that futures subtype each other properly. """ # Successful subtyping. def returns_int(x: int) -> int: return x + x + 1 def returns_future_any(x: int) -> torch.jit.Future[Any]: return torch.jit._fork(returns_int, (x)) @torch.jit.script def fn_int(x: int) -> Any: fut = returns_future_any(x) return fut.wait() # Unsuccessful subtyping. with self.assertRaisesRegexWithHighlight( RuntimeError, r"was annotated as having type Future\[float\] but is actually of type Future\[int\]", "fut = returns_future_float(x", ): def returns_future_float(x: int) -> torch.jit.Future[float]: return torch.jit._fork(returns_int, (x)) @torch.jit.script def fn_float(x: int) -> Any: fut = returns_future_float(x) return fut.wait() if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." )