# mypy: allow-untyped-defs

from typing import Dict, Tuple

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.distributed.rpc import rpc_async
from torch.testing import FileCheck
from torch.testing._internal.dist_utils import dist_init, worker_name
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)


@torch.jit.script
def local_add(t1, t2):
    return torch.add(t1, t2)


@torch.jit.script
def remote_add(t1, t2, dst: str):  # noqa: E999
    return rpc_async(dst, local_add, (t1, t2)).wait()


@torch.jit.script
def fork_add(t1, t2, dst: str):
    fut = torch.jit._fork(remote_add, t1, t2, dst)
    return torch.jit._wait(fut)


class JitDistAutogradTest(RpcAgentTestFixture):
    @dist_init
    def test_get_gradients(self):
        # Verify that dist_autograd.get_gradients is scriptable and returns
        # gradients for both leaf tensors after a distributed backward pass.
        dst_rank = self.rank

        @torch.jit.script
        def dist_get_gradients(context_id: int) -> (Dict[Tensor, Tensor]):
            return dist_autograd.get_gradients(context_id)

        FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            t3 = torch.add(t1, t2)

            dist_autograd.backward(context_id, [t3.sum()])
            grads = dist_get_gradients(context_id)

            self.assertEqual(2, len(grads))
            self.assertIn(t1, grads)
            self.assertIn(t2, grads)
            self.assertEqual(torch.ones(3, 3), grads[t1])
            self.assertEqual(torch.ones(3, 3), grads[t2])

    @dist_init
    def test_dist_backward(self):
        # Verify that dist_autograd.backward is scriptable and runs on a loss
        # computed from the result of a remote rpc_sync call.
        if self.rank != 0:
            return

        @torch.jit.script
        def dist_backward_script(context_id: int, loss: torch.Tensor):
            dist_autograd.backward(context_id, [loss])

        FileCheck().check("dist_backward").run(str(dist_backward_script.graph))
        with dist_autograd.context() as context_id:
            t1 = torch.rand(3, 3, requires_grad=True)
            t2 = torch.rand(3, 3, requires_grad=True)
            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
            loss = rpc.rpc_sync(dst_worker_name, torch.add, args=(t1, t2)).sum()
            dist_backward_script(context_id, loss)

    @dist_init
    def test_jit_fork_within_context(self):
        # Verify that torch.jit._fork of a scripted function that issues an RPC
        # works inside a dist_autograd context and that gradients are recorded.
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
            res = fork_add(t1, t2, dst_worker_name)
            loss = res.sum()
            dist_autograd.backward(context_id, [loss])

            grads = dist_autograd.get_gradients(context_id)
            self.assertEqual(2, len(grads))
            self.assertIn(t1, grads)
            self.assertIn(t2, grads)

    @dist_init
    def test_restore_context_after_switch_to_jit_thread(self):
        if self.rank != 0:
            return

        @torch.jit.script
        def forward_script(
            context_id: int, dst_worker_name: str, t1: Tensor, t2: Tensor
        ) -> Tuple[Tensor, Tensor]:
            res1_fut = rpc.rpc_async(dst_worker_name, local_add, (t1, t1))
            res1 = res1_fut.wait()  # After this, the script runs in a new JIT thread.
            loss1 = res1.sum()

            # SendRpcBackward is not attached, since DistAutogradContext is lost here.
            res2_fut = rpc.rpc_async(dst_worker_name, local_add, (t2, t2))
            res2 = res2_fut.wait()
            loss2 = res2.sum()

            return loss1, loss2

        with dist_autograd.context() as context_id:
            t1 = torch.ones((2, 3), requires_grad=True)
            t2 = torch.ones((2, 3), requires_grad=True)
            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
            loss0, loss1 = forward_script(context_id, dst_worker_name, t1, t2)
            dist_autograd.backward(context_id, [loss0, loss1])
            grad0, grad1 = dist_autograd.get_gradients(context_id)
            self.assertEqual(grad0, grad1)