1# Owner(s): ["oncall: jit"] 2 3import os 4import sys 5from typing import Any, Tuple 6 7import torch 8import torch.nn as nn 9 10 11# Make the helper files in test/ importable 12pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) 13sys.path.append(pytorch_test_dir) 14from typing import List 15 16from torch import Tensor 17from torch.jit import Future 18from torch.testing._internal.jit_utils import _inline_everything, JitTestCase 19 20 21class TestAsync(JitTestCase): 22 def test_async_python(self): 23 @torch.jit.script 24 def foo(x): 25 return torch.neg(x) 26 27 x = torch.rand(3, 4) 28 fut = torch.jit.fork(foo, x) 29 y_hat = foo(x) 30 y = torch.jit.wait(fut) 31 # assert nothing; only to make sure the fake python path works 32 33 def test_async_future_type_python(self): 34 def foo(inp): 35 futures = torch.jit.annotate(List[torch.jit.Future[torch.Tensor]], []) 36 for i in range(5): 37 futures.append(torch.jit.fork(lambda x: x, inp)) 38 all_outputs = [] 39 for future in futures: 40 all_outputs.append(torch.jit.wait(future)) 41 return all_outputs 42 43 # assert nothing, just to make sure python type parsing works 44 foo(torch.randn(3, 4)) 45 46 def test_async_parsing(self): 47 @torch.jit.script 48 def foo(x: Tensor) -> List[Tensor]: 49 return [torch.neg(x), x.t()] 50 51 @torch.jit.script 52 def bar(x): 53 futures = torch.jit.annotate(List[Future[List[Tensor]]], []) 54 for _ in range(3): 55 future = torch.jit.annotate( 56 Future[List[Tensor]], torch.jit.fork(foo, x) 57 ) 58 futures.append(future) 59 60 output = torch.jit.annotate(List[List[Tensor]], []) 61 for i in range(3): 62 output.append(torch.jit.wait(futures[i])) 63 return output 64 65 x = torch.rand(3, 3) 66 result = bar(x) 67 self.assertEqual(len(result), 3) 68 69 def test_async_script(self): 70 @torch.jit.script 71 def foo(x): 72 return torch.neg(x), x 73 74 x = torch.rand(3, 4) 75 76 @torch.jit.script 77 def wait_script(x): 78 fut = torch.jit.fork(foo, x) 79 y_hat = foo(x) 80 y = torch.jit.wait(fut) 81 return y, y_hat 82 83 y, y_hat = wait_script(x) 84 85 self.assertEqual(y, y_hat) 86 87 def test_async_script_capture(self): 88 class Mod(torch.jit.ScriptModule): 89 __constants__ = ["const"] 90 91 def __init__(self) -> None: 92 super().__init__() 93 self.const = 42 94 self.param = nn.Parameter(torch.randn(2, 2)) 95 96 @torch.jit.script_method 97 def foo(self, x1, x2): 98 return torch.neg(x1), self.param, self.const, torch.neg(x2), self.param 99 100 @torch.jit.script_method 101 def forward(self, x1, x2): 102 fut = torch.jit.fork(self.foo, x1, x2) 103 y_hat = self.foo(x1, x2) 104 y = torch.jit.wait(fut) 105 return y, y_hat 106 107 x1 = torch.rand(3, 4) 108 x2 = torch.rand(5, 6) 109 110 m = Mod() 111 112 with torch.jit.optimized_execution(False): 113 y, y_hat = m.forward(x1, x2) 114 115 self.assertEqual(y, y_hat) 116 117 def test_async_script_nested(self): 118 @torch.jit.script 119 def foo(x): 120 return torch.neg(x), x 121 122 x = torch.rand(3, 4) 123 124 @torch.jit.script 125 def wait_script(x): 126 fut = torch.jit._fork(foo, x) 127 y_hat = foo(x) 128 y = torch.jit._wait(fut) 129 return y, y_hat 130 131 @torch.jit.script 132 def wait_script_nest(x): 133 fut = torch.jit._fork(wait_script, x) 134 return torch.jit._wait(fut) 135 136 y, y_hat = wait_script_nest(x) 137 138 self.assertEqual(y, y_hat) 139 140 def test_async_script_no_script_mod(self): 141 x = torch.rand(3, 4) 142 143 with self.assertRaisesRegexWithHighlight( 144 RuntimeError, "cannot call a value", "torch.jit._fork(x" 145 ): 146 147 @torch.jit.script 148 def wait_script(x): 149 fut = torch.jit._fork(x) 150 return fut 151 152 def test_async_script_multi_waits(self): 153 @torch.jit.script 154 def foo(x): 155 return torch.neg(x).t() + x 156 157 @torch.jit.script 158 def wait_script(x): 159 fut = torch.jit._fork(foo, x) 160 161 # wait twice on the same future 162 y1 = torch.jit._wait(fut) 163 y2 = torch.jit._wait(fut) 164 return y1, y2 165 166 x = torch.rand(2, 2) 167 y1, y2 = wait_script(x) 168 self.assertEqual(y1, y2) 169 170 def test_async_script_multi_forks(self): 171 @torch.jit.script 172 def foo1(x): 173 return torch.neg(x).t() + x 174 175 @torch.jit.script 176 def foo2(x, y): 177 return torch.neg(x).t() + x + torch.neg(y).t() 178 179 @torch.jit.script 180 def foo3(x, y, z): 181 return torch.neg(z).t() + y.t() + x 182 183 x1 = torch.rand(10, 10) 184 x2 = torch.rand(10, 10) 185 x3 = torch.rand(10, 10) 186 187 @torch.jit.script 188 def wait_script(x1, x2, x3): 189 f1 = torch.jit._fork(foo1, x1) 190 f2 = torch.jit._fork(foo2, x1, x2) 191 f3 = torch.jit._fork(foo3, x1, x2, x3) 192 f4 = torch.jit._fork(foo1, x2) 193 f5 = torch.jit._fork(foo2, x2, x3) 194 195 # ignore some forks 196 y1 = torch.jit._wait(f1) 197 y2 = torch.jit._wait(f2) 198 y3 = torch.jit._wait(f3) 199 200 return y1, y2, y3 201 202 y1, y2, y3 = wait_script(x1, x2, x3) 203 self.assertEqual(y1, foo1(x1)) 204 self.assertEqual(y2, foo2(x1, x2)) 205 self.assertEqual(y3, foo3(x1, x2, x3)) 206 207 def test_async_kwargs(self): 208 def foo(x1, x2): 209 return 2 * x1 + x2 210 211 x1 = torch.rand(3, 4) 212 x2 = torch.rand(3, 4) 213 y_hat = foo(x1, x2) 214 215 # Cover tracing and bare functions with permutations of args, kwargs 216 for func in [ 217 lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2)), 218 lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1, x2=x2)), 219 lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)), 220 lambda x1, x2: torch.jit._wait(torch.jit._fork(foo, x2=x2, x1=x1)), 221 ]: 222 for wrapper in [ 223 func, 224 torch.jit.trace(func, (x1, x2)), 225 ]: 226 self.assertEqual(wrapper(x1, x2), y_hat) 227 self.assertEqual(wrapper(x1, x2=x2), y_hat) 228 self.assertEqual(wrapper(x1=x1, x2=x2), y_hat) 229 self.assertEqual(wrapper(x2=x2, x1=x1), y_hat) 230 231 # Cover scripting 232 @torch.jit.script 233 def foo_script_args(x1, x2): 234 return torch.jit._wait(torch.jit._fork(foo, x1, x2)) 235 236 @torch.jit.script 237 def foo_script_kwargs(x1, x2): 238 return torch.jit._wait(torch.jit._fork(foo, x1=x1, x2=x2)) 239 240 for wrapper in [ 241 foo_script_args, 242 foo_script_kwargs, 243 ]: 244 self.assertEqual(wrapper(x1, x2), y_hat) 245 self.assertEqual(wrapper(x1, x2=x2), y_hat) 246 self.assertEqual(wrapper(x1=x1, x2=x2), y_hat) 247 self.assertEqual(wrapper(x2=x2, x1=x1), y_hat) 248 249 @_inline_everything 250 def test_async_script_trace(self): 251 class Traced(nn.Module): 252 def forward(self, x): 253 return (torch.neg(x), x) 254 255 class Mod(torch.jit.ScriptModule): 256 def __init__(self) -> None: 257 super().__init__() 258 x = torch.rand(3, 3) 259 self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True) 260 261 @torch.jit.script_method 262 def forward( 263 self, x: Tensor 264 ) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]: 265 future1 = torch.jit._fork(self.traced, x) 266 future2 = torch.jit._fork(torch.neg, x) 267 268 tensor_tuple = torch.jit._wait(future1) 269 tensor_single = torch.jit._wait(future2) 270 271 tensor_list = [] 272 tensor_list.append(tensor_tuple[0]) 273 tensor_list.append(tensor_single) 274 275 # return a nested structure of tensors 276 return (tensor_list, tensor_tuple, tensor_tuple[1]) 277 278 class TupleCl(nn.Module): 279 def __init__(self) -> None: 280 super().__init__() 281 self.module = Mod() 282 283 def forward(self, x): 284 z = torch.neg(x) 285 y = self.module(x) 286 list = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]] 287 return tuple(list) 288 289 x = torch.rand(3, 3) 290 module = torch.jit.trace(TupleCl(), (x), _force_outplace=True) 291 292 # Make sure we have forks 293 self.assertGraphContainsExactly( 294 module.graph, kind="prim::fork", num_kind_nodes=2 295 ) 296 # Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs 297 self.assertGraphContainsExactly( 298 module.graph, kind="aten::neg", num_kind_nodes=1 299 ) 300 self.assertGraphContainsExactly( 301 module.graph, kind="aten::neg", num_kind_nodes=3, consider_subgraphs=True 302 ) 303 304 y = torch.neg(x) 305 self.assertEqual(module(x), (y, y, y, y, x, x)) 306 307 def test_async_script_error(self): 308 x = torch.rand(3, 4) 309 310 @torch.jit.script 311 def foo(x): 312 # error here 313 return x.t() + x 314 315 @torch.jit.script 316 def wait_script(x): 317 fut = torch.jit._fork(foo, x) 318 return torch.jit._wait(fut) 319 320 @torch.jit.script 321 def wait_script_nest(x): 322 fut = torch.jit._fork(wait_script, x) 323 return torch.jit._wait(fut) 324 325 # no future 326 error_msg = "The size.*must match the size of tensor" 327 with self.assertRaisesRegexWithHighlight(Exception, error_msg, "x.t() + x"): 328 foo(x) 329 330 # one future 331 with self.assertRaisesRegexWithHighlight( 332 Exception, error_msg, "torch.jit._fork(foo, x" 333 ): 334 wait_script(x) 335 336 # two futures with a different error 337 x = torch.rand(3, 4, 5) 338 with self.assertRaisesRegexWithHighlight( 339 Exception, 340 "expects a tensor with <= 2 dimensions", 341 "torch.jit._fork(wait_script, x", 342 ): 343 wait_script_nest(x) 344 345 def test_async_grad_guard_with_grad(self): 346 @torch.jit.script 347 def foo(x): 348 y = x * 2 349 return y.requires_grad 350 351 @torch.jit.script 352 def bar(x): 353 fut = torch.jit._fork(foo, x) 354 requires_grad_in_fork = torch.jit._wait(fut) 355 z = x * 2 356 return (requires_grad_in_fork, z.requires_grad) 357 358 x = torch.randn(3, requires_grad=True) 359 360 with torch.enable_grad(): 361 (inside_fork, after_wait) = bar(x) 362 363 self.assertEqual(inside_fork, True) 364 self.assertEqual(after_wait, True) 365 366 def test_async_grad_guard_no_grad(self): 367 @torch.jit.script 368 def foo(x): 369 y = x * 2 370 return y.requires_grad 371 372 @torch.jit.script 373 def bar(x): 374 fut = torch.jit._fork(foo, x) 375 requires_grad_in_fork = torch.jit._wait(fut) 376 z = x * 2 377 return (requires_grad_in_fork, z.requires_grad) 378 379 x = torch.randn(3, requires_grad=True) 380 381 with torch.no_grad(): 382 (inside_fork, after_wait) = bar(x) 383 384 self.assertEqual(inside_fork, False) 385 self.assertEqual(after_wait, False) 386 387 def test_trace_fork_wait(self): 388 def fork_body(x): 389 return x.neg(), x.neg() + 1 390 391 def fn(x): 392 fut = torch.jit._fork(fork_body, x) 393 vals = torch.jit._wait(fut) 394 return vals[0], vals[1], x - 1 395 396 traced = torch.jit.trace(fn, (torch.rand(3, 4),)) 397 x = torch.rand(3, 4) 398 self.assertEqual(fn(x), traced(x)) 399 400 self.assertGraphContainsExactly( 401 traced.graph, kind="prim::fork", num_kind_nodes=1 402 ) 403 self.assertGraphContainsExactly( 404 traced.graph, kind="aten::wait", num_kind_nodes=1 405 ) 406 self.assertGraphContainsExactly( 407 traced.graph, kind="aten::neg", num_kind_nodes=2, consider_subgraphs=True 408 ) 409 410 def test_trace_fork_wait_leaking(self): 411 my_list = [] 412 413 def fork_body(x): 414 my_list.append(x + 1) 415 return x + 1 416 417 def fn(x): 418 fut = torch.jit._fork(fork_body, x) 419 val = torch.jit._wait(fut) 420 return my_list[0] 421 422 with self.assertRaisesRegexWithHighlight( 423 RuntimeError, 424 "did not have observable data dependence with trace inputs; " 425 "this probably indicates your program cannot be understood " 426 "by the tracer.", 427 "", 428 ): 429 traced = torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False) 430 431 def test_trace_fork_wait_inline(self): 432 def fork_body(x): 433 return x + 1, x + 2 434 435 def fn(x): 436 fut = torch.jit._fork(fork_body, x) 437 val = torch.jit._wait(fut) 438 return val[1] 439 440 traced = torch.jit.trace(fn, (torch.rand(3, 4),)) 441 torch._C._jit_pass_inline_fork_wait(traced.graph) 442 self.assertGraphContainsExactly( 443 traced.graph, kind="prim::fork", num_kind_nodes=0 444 ) 445 self.assertGraphContainsExactly( 446 traced.graph, kind="aten::wait", num_kind_nodes=0 447 ) 448 self.assertGraphContainsExactly( 449 traced.graph, kind="aten::add", num_kind_nodes=2 450 ) 451 452 def test_trace_fork_wait_list_modulecalls(self): 453 def add_one(input): 454 return input + torch.ones(input.size()) 455 456 class TestListFutureModule(nn.Module): 457 def forward(self, input): 458 input_list = [] 459 for i in range(3): 460 input_list.append(input) 461 462 fut_list: List[Future[torch.Tensor]] = [] 463 for input_tensor in input_list: 464 fut_list.append(torch.jit._fork(add_one, input_tensor)) 465 # return list[future[tensor]] here to ensure tracing 466 # module calls return the correct types 467 return fut_list 468 469 class TestModuleWrapper(nn.Module): 470 def __init__(self) -> None: 471 super().__init__() 472 self.list_fut_mod = TestListFutureModule() 473 474 def forward(self, input): 475 fut_list = self.list_fut_mod(input) 476 res = input 477 for fut in fut_list: 478 res = res + fut.wait() 479 return res 480 481 self.checkTrace(TestModuleWrapper(), (torch.randn(5, 5),)) 482 483 def test_trace_modulecalls_with_different_output_types(self): 484 def add_one(input): 485 return input + torch.ones(input.size()) 486 487 class DifferentOutputModule(nn.Module): 488 def forward(self, input): 489 fut_res = torch.jit._fork(add_one, (input)) 490 491 # return different types from module call 492 return input, fut_res 493 494 class TestModule(nn.Module): 495 def __init__(self) -> None: 496 super().__init__() 497 self.gen_output = DifferentOutputModule() 498 499 def forward(self, input): 500 res, fut_res = self.gen_output(input) 501 res = res + fut_res.wait() 502 return res 503 504 self.checkTrace(TestModule(), (torch.randn(5, 5),)) 505 506 def test_no_future_subtype_message(self): 507 with self.assertRaisesRegexWithHighlight( 508 RuntimeError, "Future without a contained type", "" 509 ): 510 511 @torch.jit.script 512 def forward(self, x): 513 futs = torch.jit.annotate(List[torch.jit.Future], []) 514 515 def test_future_subtyping(self): 516 """ 517 Test that futures subtype each other properly. 518 """ 519 520 # Successful subtyping. 521 def returns_int(x: int) -> int: 522 return x + x + 1 523 524 def returns_future_any(x: int) -> torch.jit.Future[Any]: 525 return torch.jit._fork(returns_int, (x)) 526 527 @torch.jit.script 528 def fn_int(x: int) -> Any: 529 fut = returns_future_any(x) 530 return fut.wait() 531 532 # Unsuccessful subtyping. 533 with self.assertRaisesRegexWithHighlight( 534 RuntimeError, 535 r"was annotated as having type Future\[float\] but is actually of type Future\[int\]", 536 "fut = returns_future_float(x", 537 ): 538 539 def returns_future_float(x: int) -> torch.jit.Future[float]: 540 return torch.jit._fork(returns_int, (x)) 541 542 @torch.jit.script 543 def fn_float(x: int) -> Any: 544 fut = returns_future_float(x) 545 return fut.wait() 546 547 548if __name__ == "__main__": 549 raise RuntimeError( 550 "This test file is not meant to be run directly, use:\n\n" 551 "\tpython test/test_jit.py TESTNAME\n\n" 552 "instead." 553 ) 554