1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport os 4*da0073e9SAndroid Build Coastguard Workerimport sys 5*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Tuple 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker# Make the helper files in test/ importable 12*da0073e9SAndroid Build Coastguard Workerpytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 13*da0073e9SAndroid Build Coastguard Workersys.path.append(pytorch_test_dir) 14*da0073e9SAndroid Build Coastguard Workerfrom typing import List 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor 17*da0073e9SAndroid Build Coastguard Workerfrom torch.jit import Future 18*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import _inline_everything, JitTestCase 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Workerclass TestAsync(JitTestCase): 22*da0073e9SAndroid Build Coastguard Worker def test_async_python(self): 23*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 24*da0073e9SAndroid Build Coastguard Worker def foo(x): 25*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 26*da0073e9SAndroid Build Coastguard Worker 27*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 28*da0073e9SAndroid Build Coastguard Worker fut = torch.jit.fork(foo, x) 29*da0073e9SAndroid Build Coastguard Worker y_hat = foo(x) 30*da0073e9SAndroid Build Coastguard Worker y = torch.jit.wait(fut) 31*da0073e9SAndroid Build Coastguard Worker # assert nothing; only to make sure the fake python path works 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build Coastguard Worker def test_async_future_type_python(self): 34*da0073e9SAndroid Build Coastguard Worker def foo(inp): 35*da0073e9SAndroid Build Coastguard Worker futures = torch.jit.annotate(List[torch.jit.Future[torch.Tensor]], []) 36*da0073e9SAndroid Build Coastguard Worker for i in range(5): 37*da0073e9SAndroid Build Coastguard Worker futures.append(torch.jit.fork(lambda x: x, inp)) 38*da0073e9SAndroid Build Coastguard Worker all_outputs = [] 39*da0073e9SAndroid Build Coastguard Worker for future in futures: 40*da0073e9SAndroid Build Coastguard Worker all_outputs.append(torch.jit.wait(future)) 41*da0073e9SAndroid Build Coastguard Worker return all_outputs 42*da0073e9SAndroid Build Coastguard Worker 43*da0073e9SAndroid Build Coastguard Worker # assert nothing, just to make sure python type parsing works 44*da0073e9SAndroid Build Coastguard Worker foo(torch.randn(3, 4)) 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker def test_async_parsing(self): 47*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 48*da0073e9SAndroid Build Coastguard Worker def foo(x: Tensor) -> List[Tensor]: 49*da0073e9SAndroid Build Coastguard Worker return [torch.neg(x), x.t()] 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 52*da0073e9SAndroid Build Coastguard Worker def bar(x): 53*da0073e9SAndroid Build Coastguard Worker futures = torch.jit.annotate(List[Future[List[Tensor]]], []) 54*da0073e9SAndroid Build Coastguard Worker for _ in range(3): 55*da0073e9SAndroid Build Coastguard Worker future = torch.jit.annotate( 56*da0073e9SAndroid Build Coastguard Worker Future[List[Tensor]], torch.jit.fork(foo, x) 57*da0073e9SAndroid Build Coastguard Worker ) 58*da0073e9SAndroid Build Coastguard Worker futures.append(future) 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker output = torch.jit.annotate(List[List[Tensor]], []) 61*da0073e9SAndroid Build Coastguard Worker for i in range(3): 62*da0073e9SAndroid Build Coastguard Worker output.append(torch.jit.wait(futures[i])) 63*da0073e9SAndroid Build Coastguard Worker return output 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 3) 66*da0073e9SAndroid Build Coastguard Worker result = bar(x) 67*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(result), 3) 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker def test_async_script(self): 70*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 71*da0073e9SAndroid Build Coastguard Worker def foo(x): 72*da0073e9SAndroid Build Coastguard Worker return torch.neg(x), x 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 77*da0073e9SAndroid Build Coastguard Worker def wait_script(x): 78*da0073e9SAndroid Build Coastguard Worker fut = torch.jit.fork(foo, x) 79*da0073e9SAndroid Build Coastguard Worker y_hat = foo(x) 80*da0073e9SAndroid Build Coastguard Worker y = torch.jit.wait(fut) 81*da0073e9SAndroid Build Coastguard Worker return y, y_hat 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker y, y_hat = wait_script(x) 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, y_hat) 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker def test_async_script_capture(self): 88*da0073e9SAndroid Build Coastguard Worker class Mod(torch.jit.ScriptModule): 89*da0073e9SAndroid Build Coastguard Worker __constants__ = ["const"] 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 92*da0073e9SAndroid Build Coastguard Worker super().__init__() 93*da0073e9SAndroid Build Coastguard Worker self.const = 42 94*da0073e9SAndroid Build Coastguard Worker self.param = nn.Parameter(torch.randn(2, 2)) 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 97*da0073e9SAndroid Build Coastguard Worker def foo(self, x1, x2): 98*da0073e9SAndroid Build Coastguard Worker return torch.neg(x1), self.param, self.const, torch.neg(x2), self.param 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 101*da0073e9SAndroid Build Coastguard Worker def forward(self, x1, x2): 102*da0073e9SAndroid Build Coastguard Worker fut = torch.jit.fork(self.foo, x1, x2) 103*da0073e9SAndroid Build Coastguard Worker y_hat = self.foo(x1, x2) 104*da0073e9SAndroid Build Coastguard Worker y = torch.jit.wait(fut) 105*da0073e9SAndroid Build Coastguard Worker return y, y_hat 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker x1 = torch.rand(3, 4) 108*da0073e9SAndroid Build Coastguard Worker x2 = torch.rand(5, 6) 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker m = Mod() 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 113*da0073e9SAndroid Build Coastguard Worker y, y_hat = m.forward(x1, x2) 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, y_hat) 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker def test_async_script_nested(self): 118*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 119*da0073e9SAndroid Build Coastguard Worker def foo(x): 120*da0073e9SAndroid Build Coastguard Worker return torch.neg(x), x 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 125*da0073e9SAndroid Build Coastguard Worker def wait_script(x): 126*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(foo, x) 127*da0073e9SAndroid Build Coastguard Worker y_hat = foo(x) 128*da0073e9SAndroid Build Coastguard Worker y = torch.jit._wait(fut) 129*da0073e9SAndroid Build Coastguard Worker return y, y_hat 130*da0073e9SAndroid Build Coastguard Worker 131*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 132*da0073e9SAndroid Build Coastguard Worker def wait_script_nest(x): 133*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(wait_script, x) 134*da0073e9SAndroid Build Coastguard Worker return torch.jit._wait(fut) 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker y, y_hat = wait_script_nest(x) 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, y_hat) 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker def test_async_script_no_script_mod(self): 141*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 144*da0073e9SAndroid Build Coastguard Worker RuntimeError, "cannot call a value", "torch.jit._fork(x" 145*da0073e9SAndroid Build Coastguard Worker ): 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 148*da0073e9SAndroid Build Coastguard Worker def wait_script(x): 149*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(x) 150*da0073e9SAndroid Build Coastguard Worker return fut 151*da0073e9SAndroid Build Coastguard Worker 152*da0073e9SAndroid Build Coastguard Worker def test_async_script_multi_waits(self): 153*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 154*da0073e9SAndroid Build Coastguard Worker def foo(x): 155*da0073e9SAndroid Build Coastguard Worker return torch.neg(x).t() + x 156*da0073e9SAndroid Build Coastguard Worker 157*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 158*da0073e9SAndroid Build Coastguard Worker def wait_script(x): 159*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(foo, x) 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker # wait twice on the same future 162*da0073e9SAndroid Build Coastguard Worker y1 = torch.jit._wait(fut) 163*da0073e9SAndroid Build Coastguard Worker y2 = torch.jit._wait(fut) 164*da0073e9SAndroid Build Coastguard Worker return y1, y2 165*da0073e9SAndroid Build Coastguard Worker 166*da0073e9SAndroid Build Coastguard Worker x = torch.rand(2, 2) 167*da0073e9SAndroid Build Coastguard Worker y1, y2 = wait_script(x) 168*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y1, y2) 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker def test_async_script_multi_forks(self): 171*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 172*da0073e9SAndroid Build Coastguard Worker def foo1(x): 173*da0073e9SAndroid Build Coastguard Worker return torch.neg(x).t() + x 174*da0073e9SAndroid Build Coastguard Worker 175*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 176*da0073e9SAndroid Build Coastguard Worker def foo2(x, y): 177*da0073e9SAndroid Build Coastguard Worker return torch.neg(x).t() + x + torch.neg(y).t() 178*da0073e9SAndroid Build Coastguard Worker 179*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 180*da0073e9SAndroid Build Coastguard Worker def foo3(x, y, z): 181*da0073e9SAndroid Build Coastguard Worker return torch.neg(z).t() + y.t() + x 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker x1 = torch.rand(10, 10) 184*da0073e9SAndroid Build Coastguard Worker x2 = torch.rand(10, 10) 185*da0073e9SAndroid Build Coastguard Worker x3 = torch.rand(10, 10) 186*da0073e9SAndroid Build Coastguard Worker 187*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 188*da0073e9SAndroid Build Coastguard Worker def wait_script(x1, x2, x3): 189*da0073e9SAndroid Build Coastguard Worker f1 = torch.jit._fork(foo1, x1) 190*da0073e9SAndroid Build Coastguard Worker f2 = torch.jit._fork(foo2, x1, x2) 191*da0073e9SAndroid Build Coastguard Worker f3 = torch.jit._fork(foo3, x1, x2, x3) 192*da0073e9SAndroid Build Coastguard Worker f4 = torch.jit._fork(foo1, x2) 193*da0073e9SAndroid Build Coastguard Worker f5 = torch.jit._fork(foo2, x2, x3) 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker # ignore some forks 196*da0073e9SAndroid Build Coastguard Worker y1 = torch.jit._wait(f1) 197*da0073e9SAndroid Build Coastguard Worker y2 = torch.jit._wait(f2) 198*da0073e9SAndroid Build Coastguard Worker y3 = torch.jit._wait(f3) 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker return y1, y2, y3 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker y1, y2, y3 = wait_script(x1, x2, x3) 203*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y1, foo1(x1)) 204*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y2, foo2(x1, x2)) 205*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y3, foo3(x1, x2, x3)) 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Worker def test_async_kwargs(self): 208*da0073e9SAndroid Build Coastguard Worker def foo(x1, x2): 209*da0073e9SAndroid Build Coastguard Worker return 2 * x1 + x2 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker x1 = torch.rand(3, 4) 212*da0073e9SAndroid Build Coastguard Worker x2 = torch.rand(3, 4) 213*da0073e9SAndroid Build Coastguard Worker y_hat = foo(x1, x2) 214*da0073e9SAndroid Build Coastguard Worker 215*da0073e9SAndroid Build Coastguard Worker # Cover tracing and bare functions with permutations of args, kwargs 216*da0073e9SAndroid Build Coastguard Worker for func in [ 217*da0073e9SAndroid Build Coastguard Worker lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2)), 218*da0073e9SAndroid Build Coastguard Worker lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2=x2)), 219*da0073e9SAndroid Build Coastguard Worker lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)), 220*da0073e9SAndroid Build Coastguard Worker lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1)), 221*da0073e9SAndroid Build Coastguard Worker ]: 222*da0073e9SAndroid Build Coastguard Worker for wrapper in [ 223*da0073e9SAndroid Build Coastguard Worker func, 224*da0073e9SAndroid Build Coastguard Worker torch.jit.trace(func, (x1, x2)), 225*da0073e9SAndroid Build Coastguard Worker ]: 226*da0073e9SAndroid Build Coastguard Worker self.assertEqual(wrapper(x1, x2), y_hat) 227*da0073e9SAndroid Build Coastguard Worker self.assertEqual(wrapper(x1, x2=x2), y_hat) 228*da0073e9SAndroid Build Coastguard Worker self.assertEqual(wrapper(x1=x1, x2=x2), y_hat) 229*da0073e9SAndroid Build Coastguard Worker self.assertEqual(wrapper(x2=x2, x1=x1), y_hat) 230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker # Cover scripting 232*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 233*da0073e9SAndroid Build Coastguard Worker def foo_script_args(x1, x2): 234*da0073e9SAndroid Build Coastguard Worker return torch.jit._wait(torch.jit._fork(foo, x1, x2)) 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 237*da0073e9SAndroid Build Coastguard Worker def foo_script_kwargs(x1, x2): 238*da0073e9SAndroid Build Coastguard Worker return torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)) 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker for wrapper in [ 241*da0073e9SAndroid Build Coastguard Worker foo_script_args, 242*da0073e9SAndroid Build Coastguard Worker foo_script_kwargs, 243*da0073e9SAndroid Build Coastguard Worker ]: 244*da0073e9SAndroid Build Coastguard Worker self.assertEqual(wrapper(x1, x2), y_hat) 245*da0073e9SAndroid Build Coastguard Worker self.assertEqual(wrapper(x1, x2=x2), y_hat) 246*da0073e9SAndroid Build Coastguard Worker self.assertEqual(wrapper(x1=x1, x2=x2), y_hat) 247*da0073e9SAndroid Build Coastguard Worker self.assertEqual(wrapper(x2=x2, x1=x1), y_hat) 248*da0073e9SAndroid Build Coastguard Worker 249*da0073e9SAndroid Build Coastguard Worker @_inline_everything 250*da0073e9SAndroid Build Coastguard Worker def test_async_script_trace(self): 251*da0073e9SAndroid Build Coastguard Worker class Traced(nn.Module): 252*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 253*da0073e9SAndroid Build Coastguard Worker return (torch.neg(x), x) 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker class Mod(torch.jit.ScriptModule): 256*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 257*da0073e9SAndroid Build Coastguard Worker super().__init__() 258*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 3) 259*da0073e9SAndroid Build Coastguard Worker self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True) 260*da0073e9SAndroid Build Coastguard Worker 261*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 262*da0073e9SAndroid Build Coastguard Worker def forward( 263*da0073e9SAndroid Build Coastguard Worker self, x: Tensor 264*da0073e9SAndroid Build Coastguard Worker ) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]: 265*da0073e9SAndroid Build Coastguard Worker future1 = torch.jit._fork(self.traced, x) 266*da0073e9SAndroid Build Coastguard Worker future2 = torch.jit._fork(torch.neg, x) 267*da0073e9SAndroid Build Coastguard Worker 268*da0073e9SAndroid Build Coastguard Worker tensor_tuple = torch.jit._wait(future1) 269*da0073e9SAndroid Build Coastguard Worker tensor_single = torch.jit._wait(future2) 270*da0073e9SAndroid Build Coastguard Worker 271*da0073e9SAndroid Build Coastguard Worker tensor_list = [] 272*da0073e9SAndroid Build Coastguard Worker tensor_list.append(tensor_tuple[0]) 273*da0073e9SAndroid Build Coastguard Worker tensor_list.append(tensor_single) 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker # return a nested structure of tensors 276*da0073e9SAndroid Build Coastguard Worker return (tensor_list, tensor_tuple, tensor_tuple[1]) 277*da0073e9SAndroid Build Coastguard Worker 278*da0073e9SAndroid Build Coastguard Worker class TupleCl(nn.Module): 279*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 280*da0073e9SAndroid Build Coastguard Worker super().__init__() 281*da0073e9SAndroid Build Coastguard Worker self.module = Mod() 282*da0073e9SAndroid Build Coastguard Worker 283*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 284*da0073e9SAndroid Build Coastguard Worker z = torch.neg(x) 285*da0073e9SAndroid Build Coastguard Worker y = self.module(x) 286*da0073e9SAndroid Build Coastguard Worker list = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]] 287*da0073e9SAndroid Build Coastguard Worker return tuple(list) 288*da0073e9SAndroid Build Coastguard Worker 289*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 3) 290*da0073e9SAndroid Build Coastguard Worker module = torch.jit.trace(TupleCl(), (x), _force_outplace=True) 291*da0073e9SAndroid Build Coastguard Worker 292*da0073e9SAndroid Build Coastguard Worker # Make sure we have forks 293*da0073e9SAndroid Build Coastguard Worker self.assertGraphContainsExactly( 294*da0073e9SAndroid Build Coastguard Worker module.graph, kind="prim::fork", num_kind_nodes=2 295*da0073e9SAndroid Build Coastguard Worker ) 296*da0073e9SAndroid Build Coastguard Worker # Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs 297*da0073e9SAndroid Build Coastguard Worker self.assertGraphContainsExactly( 298*da0073e9SAndroid Build Coastguard Worker module.graph, kind="aten::neg", num_kind_nodes=1 299*da0073e9SAndroid Build Coastguard Worker ) 300*da0073e9SAndroid Build Coastguard Worker self.assertGraphContainsExactly( 301*da0073e9SAndroid Build Coastguard Worker module.graph, kind="aten::neg", num_kind_nodes=3, consider_subgraphs=True 302*da0073e9SAndroid Build Coastguard Worker ) 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker y = torch.neg(x) 305*da0073e9SAndroid Build Coastguard Worker self.assertEqual(module(x), (y, y, y, y, x, x)) 306*da0073e9SAndroid Build Coastguard Worker 307*da0073e9SAndroid Build Coastguard Worker def test_async_script_error(self): 308*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 311*da0073e9SAndroid Build Coastguard Worker def foo(x): 312*da0073e9SAndroid Build Coastguard Worker # error here 313*da0073e9SAndroid Build Coastguard Worker return x.t() + x 314*da0073e9SAndroid Build Coastguard Worker 315*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 316*da0073e9SAndroid Build Coastguard Worker def wait_script(x): 317*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(foo, x) 318*da0073e9SAndroid Build Coastguard Worker return torch.jit._wait(fut) 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 321*da0073e9SAndroid Build Coastguard Worker def wait_script_nest(x): 322*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(wait_script, x) 323*da0073e9SAndroid Build Coastguard Worker return torch.jit._wait(fut) 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker # no future 326*da0073e9SAndroid Build Coastguard Worker error_msg = "The size.*must match the size of tensor" 327*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight(Exception, error_msg, "x.t() + x"): 328*da0073e9SAndroid Build Coastguard Worker foo(x) 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Worker # one future 331*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 332*da0073e9SAndroid Build Coastguard Worker Exception, error_msg, "torch.jit._fork(foo, x" 333*da0073e9SAndroid Build Coastguard Worker ): 334*da0073e9SAndroid Build Coastguard Worker wait_script(x) 335*da0073e9SAndroid Build Coastguard Worker 336*da0073e9SAndroid Build Coastguard Worker # two futures with a different error 337*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4, 5) 338*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 339*da0073e9SAndroid Build Coastguard Worker Exception, 340*da0073e9SAndroid Build Coastguard Worker "expects a tensor with <= 2 dimensions", 341*da0073e9SAndroid Build Coastguard Worker "torch.jit._fork(wait_script, x", 342*da0073e9SAndroid Build Coastguard Worker ): 343*da0073e9SAndroid Build Coastguard Worker wait_script_nest(x) 344*da0073e9SAndroid Build Coastguard Worker 345*da0073e9SAndroid Build Coastguard Worker def test_async_grad_guard_with_grad(self): 346*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 347*da0073e9SAndroid Build Coastguard Worker def foo(x): 348*da0073e9SAndroid Build Coastguard Worker y = x * 2 349*da0073e9SAndroid Build Coastguard Worker return y.requires_grad 350*da0073e9SAndroid Build Coastguard Worker 351*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 352*da0073e9SAndroid Build Coastguard Worker def bar(x): 353*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(foo, x) 354*da0073e9SAndroid Build Coastguard Worker requires_grad_in_fork = torch.jit._wait(fut) 355*da0073e9SAndroid Build Coastguard Worker z = x * 2 356*da0073e9SAndroid Build Coastguard Worker return (requires_grad_in_fork, z.requires_grad) 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 359*da0073e9SAndroid Build Coastguard Worker 360*da0073e9SAndroid Build Coastguard Worker with torch.enable_grad(): 361*da0073e9SAndroid Build Coastguard Worker (inside_fork, after_wait) = bar(x) 362*da0073e9SAndroid Build Coastguard Worker 363*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inside_fork, True) 364*da0073e9SAndroid Build Coastguard Worker self.assertEqual(after_wait, True) 365*da0073e9SAndroid Build Coastguard Worker 366*da0073e9SAndroid Build Coastguard Worker def test_async_grad_guard_no_grad(self): 367*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 368*da0073e9SAndroid Build Coastguard Worker def foo(x): 369*da0073e9SAndroid Build Coastguard Worker y = x * 2 370*da0073e9SAndroid Build Coastguard Worker return y.requires_grad 371*da0073e9SAndroid Build Coastguard Worker 372*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 373*da0073e9SAndroid Build Coastguard Worker def bar(x): 374*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(foo, x) 375*da0073e9SAndroid Build Coastguard Worker requires_grad_in_fork = torch.jit._wait(fut) 376*da0073e9SAndroid Build Coastguard Worker z = x * 2 377*da0073e9SAndroid Build Coastguard Worker return (requires_grad_in_fork, z.requires_grad) 378*da0073e9SAndroid Build Coastguard Worker 379*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 380*da0073e9SAndroid Build Coastguard Worker 381*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 382*da0073e9SAndroid Build Coastguard Worker (inside_fork, after_wait) = bar(x) 383*da0073e9SAndroid Build Coastguard Worker 384*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inside_fork, False) 385*da0073e9SAndroid Build Coastguard Worker self.assertEqual(after_wait, False) 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build Coastguard Worker def test_trace_fork_wait(self): 388*da0073e9SAndroid Build Coastguard Worker def fork_body(x): 389*da0073e9SAndroid Build Coastguard Worker return x.neg(), x.neg() + 1 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker def fn(x): 392*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(fork_body, x) 393*da0073e9SAndroid Build Coastguard Worker vals = torch.jit._wait(fut) 394*da0073e9SAndroid Build Coastguard Worker return vals[0], vals[1], x - 1 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(fn, (torch.rand(3, 4),)) 397*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 398*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), traced(x)) 399*da0073e9SAndroid Build Coastguard Worker 400*da0073e9SAndroid Build Coastguard Worker self.assertGraphContainsExactly( 401*da0073e9SAndroid Build Coastguard Worker traced.graph, kind="prim::fork", num_kind_nodes=1 402*da0073e9SAndroid Build Coastguard Worker ) 403*da0073e9SAndroid Build Coastguard Worker self.assertGraphContainsExactly( 404*da0073e9SAndroid Build Coastguard Worker traced.graph, kind="aten::wait", num_kind_nodes=1 405*da0073e9SAndroid Build Coastguard Worker ) 406*da0073e9SAndroid Build Coastguard Worker self.assertGraphContainsExactly( 407*da0073e9SAndroid Build Coastguard Worker traced.graph, kind="aten::neg", num_kind_nodes=2, consider_subgraphs=True 408*da0073e9SAndroid Build Coastguard Worker ) 409*da0073e9SAndroid Build Coastguard Worker 410*da0073e9SAndroid Build Coastguard Worker def test_trace_fork_wait_leaking(self): 411*da0073e9SAndroid Build Coastguard Worker my_list = [] 412*da0073e9SAndroid Build Coastguard Worker 413*da0073e9SAndroid Build Coastguard Worker def fork_body(x): 414*da0073e9SAndroid Build Coastguard Worker my_list.append(x + 1) 415*da0073e9SAndroid Build Coastguard Worker return x + 1 416*da0073e9SAndroid Build Coastguard Worker 417*da0073e9SAndroid Build Coastguard Worker def fn(x): 418*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(fork_body, x) 419*da0073e9SAndroid Build Coastguard Worker val = torch.jit._wait(fut) 420*da0073e9SAndroid Build Coastguard Worker return my_list[0] 421*da0073e9SAndroid Build Coastguard Worker 422*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 423*da0073e9SAndroid Build Coastguard Worker RuntimeError, 424*da0073e9SAndroid Build Coastguard Worker "did not have observable data dependence with trace inputs; " 425*da0073e9SAndroid Build Coastguard Worker "this probably indicates your program cannot be understood " 426*da0073e9SAndroid Build Coastguard Worker "by the tracer.", 427*da0073e9SAndroid Build Coastguard Worker "", 428*da0073e9SAndroid Build Coastguard Worker ): 429*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False) 430*da0073e9SAndroid Build Coastguard Worker 431*da0073e9SAndroid Build Coastguard Worker def test_trace_fork_wait_inline(self): 432*da0073e9SAndroid Build Coastguard Worker def fork_body(x): 433*da0073e9SAndroid Build Coastguard Worker return x + 1, x + 2 434*da0073e9SAndroid Build Coastguard Worker 435*da0073e9SAndroid Build Coastguard Worker def fn(x): 436*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(fork_body, x) 437*da0073e9SAndroid Build Coastguard Worker val = torch.jit._wait(fut) 438*da0073e9SAndroid Build Coastguard Worker return val[1] 439*da0073e9SAndroid Build Coastguard Worker 440*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(fn, (torch.rand(3, 4),)) 441*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_inline_fork_wait(traced.graph) 442*da0073e9SAndroid Build Coastguard Worker self.assertGraphContainsExactly( 443*da0073e9SAndroid Build Coastguard Worker traced.graph, kind="prim::fork", num_kind_nodes=0 444*da0073e9SAndroid Build Coastguard Worker ) 445*da0073e9SAndroid Build Coastguard Worker self.assertGraphContainsExactly( 446*da0073e9SAndroid Build Coastguard Worker traced.graph, kind="aten::wait", num_kind_nodes=0 447*da0073e9SAndroid Build Coastguard Worker ) 448*da0073e9SAndroid Build Coastguard Worker self.assertGraphContainsExactly( 449*da0073e9SAndroid Build Coastguard Worker traced.graph, kind="aten::add", num_kind_nodes=2 450*da0073e9SAndroid Build Coastguard Worker ) 451*da0073e9SAndroid Build Coastguard Worker 452*da0073e9SAndroid Build Coastguard Worker def test_trace_fork_wait_list_modulecalls(self): 453*da0073e9SAndroid Build Coastguard Worker def add_one(input): 454*da0073e9SAndroid Build Coastguard Worker return input + torch.ones(input.size()) 455*da0073e9SAndroid Build Coastguard Worker 456*da0073e9SAndroid Build Coastguard Worker class TestListFutureModule(nn.Module): 457*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 458*da0073e9SAndroid Build Coastguard Worker input_list = [] 459*da0073e9SAndroid Build Coastguard Worker for i in range(3): 460*da0073e9SAndroid Build Coastguard Worker input_list.append(input) 461*da0073e9SAndroid Build Coastguard Worker 462*da0073e9SAndroid Build Coastguard Worker fut_list: List[Future[torch.Tensor]] = [] 463*da0073e9SAndroid Build Coastguard Worker for input_tensor in input_list: 464*da0073e9SAndroid Build Coastguard Worker fut_list.append(torch.jit._fork(add_one, input_tensor)) 465*da0073e9SAndroid Build Coastguard Worker # return list[future[tensor]] here to ensure tracing 466*da0073e9SAndroid Build Coastguard Worker # module calls return the correct types 467*da0073e9SAndroid Build Coastguard Worker return fut_list 468*da0073e9SAndroid Build Coastguard Worker 469*da0073e9SAndroid Build Coastguard Worker class TestModuleWrapper(nn.Module): 470*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 471*da0073e9SAndroid Build Coastguard Worker super().__init__() 472*da0073e9SAndroid Build Coastguard Worker self.list_fut_mod = TestListFutureModule() 473*da0073e9SAndroid Build Coastguard Worker 474*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 475*da0073e9SAndroid Build Coastguard Worker fut_list = self.list_fut_mod(input) 476*da0073e9SAndroid Build Coastguard Worker res = input 477*da0073e9SAndroid Build Coastguard Worker for fut in fut_list: 478*da0073e9SAndroid Build Coastguard Worker res = res + fut.wait() 479*da0073e9SAndroid Build Coastguard Worker return res 480*da0073e9SAndroid Build Coastguard Worker 481*da0073e9SAndroid Build Coastguard Worker self.checkTrace(TestModuleWrapper(), (torch.randn(5, 5),)) 482*da0073e9SAndroid Build Coastguard Worker 483*da0073e9SAndroid Build Coastguard Worker def test_trace_modulecalls_with_different_output_types(self): 484*da0073e9SAndroid Build Coastguard Worker def add_one(input): 485*da0073e9SAndroid Build Coastguard Worker return input + torch.ones(input.size()) 486*da0073e9SAndroid Build Coastguard Worker 487*da0073e9SAndroid Build Coastguard Worker class DifferentOutputModule(nn.Module): 488*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 489*da0073e9SAndroid Build Coastguard Worker fut_res = torch.jit._fork(add_one, (input)) 490*da0073e9SAndroid Build Coastguard Worker 491*da0073e9SAndroid Build Coastguard Worker # return different types from module call 492*da0073e9SAndroid Build Coastguard Worker return input, fut_res 493*da0073e9SAndroid Build Coastguard Worker 494*da0073e9SAndroid Build Coastguard Worker class TestModule(nn.Module): 495*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 496*da0073e9SAndroid Build Coastguard Worker super().__init__() 497*da0073e9SAndroid Build Coastguard Worker self.gen_output = DifferentOutputModule() 498*da0073e9SAndroid Build Coastguard Worker 499*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 500*da0073e9SAndroid Build Coastguard Worker res, fut_res = self.gen_output(input) 501*da0073e9SAndroid Build Coastguard Worker res = res + fut_res.wait() 502*da0073e9SAndroid Build Coastguard Worker return res 503*da0073e9SAndroid Build Coastguard Worker 504*da0073e9SAndroid Build Coastguard Worker self.checkTrace(TestModule(), (torch.randn(5, 5),)) 505*da0073e9SAndroid Build Coastguard Worker 506*da0073e9SAndroid Build Coastguard Worker def test_no_future_subtype_message(self): 507*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 508*da0073e9SAndroid Build Coastguard Worker RuntimeError, "Future without a contained type", "" 509*da0073e9SAndroid Build Coastguard Worker ): 510*da0073e9SAndroid Build Coastguard Worker 511*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 512*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 513*da0073e9SAndroid Build Coastguard Worker futs = torch.jit.annotate(List[torch.jit.Future], []) 514*da0073e9SAndroid Build Coastguard Worker 515*da0073e9SAndroid Build Coastguard Worker def test_future_subtyping(self): 516*da0073e9SAndroid Build Coastguard Worker """ 517*da0073e9SAndroid Build Coastguard Worker Test that futures subtype each other properly. 518*da0073e9SAndroid Build Coastguard Worker """ 519*da0073e9SAndroid Build Coastguard Worker 520*da0073e9SAndroid Build Coastguard Worker # Successful subtyping. 521*da0073e9SAndroid Build Coastguard Worker def returns_int(x: int) -> int: 522*da0073e9SAndroid Build Coastguard Worker return x + x + 1 523*da0073e9SAndroid Build Coastguard Worker 524*da0073e9SAndroid Build Coastguard Worker def returns_future_any(x: int) -> torch.jit.Future[Any]: 525*da0073e9SAndroid Build Coastguard Worker return torch.jit._fork(returns_int, (x)) 526*da0073e9SAndroid Build Coastguard Worker 527*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 528*da0073e9SAndroid Build Coastguard Worker def fn_int(x: int) -> Any: 529*da0073e9SAndroid Build Coastguard Worker fut = returns_future_any(x) 530*da0073e9SAndroid Build Coastguard Worker return fut.wait() 531*da0073e9SAndroid Build Coastguard Worker 532*da0073e9SAndroid Build Coastguard Worker # Unsuccessful subtyping. 533*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight( 534*da0073e9SAndroid Build Coastguard Worker RuntimeError, 535*da0073e9SAndroid Build Coastguard Worker r"was annotated as having type Future\[float\] but is actually of type Future\[int\]", 536*da0073e9SAndroid Build Coastguard Worker "fut = returns_future_float(x", 537*da0073e9SAndroid Build Coastguard Worker ): 538*da0073e9SAndroid Build Coastguard Worker 539*da0073e9SAndroid Build Coastguard Worker def returns_future_float(x: int) -> torch.jit.Future[float]: 540*da0073e9SAndroid Build Coastguard Worker return torch.jit._fork(returns_int, (x)) 541*da0073e9SAndroid Build Coastguard Worker 542*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 543*da0073e9SAndroid Build Coastguard Worker def fn_float(x: int) -> Any: 544*da0073e9SAndroid Build Coastguard Worker fut = returns_future_float(x) 545*da0073e9SAndroid Build Coastguard Worker return fut.wait() 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Worker 548*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 549*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 550*da0073e9SAndroid Build Coastguard Worker "This test file is not meant to be run directly, use:\n\n" 551*da0073e9SAndroid Build Coastguard Worker "\tpython test/test_jit.py TESTNAME\n\n" 552*da0073e9SAndroid Build Coastguard Worker "instead." 553*da0073e9SAndroid Build Coastguard Worker ) 554