# mypy: allow-untyped-defs

from typing import Dict, Tuple

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import Tensor
from torch.distributed.rpc import rpc_async
from torch.testing import FileCheck
from torch.testing._internal.dist_utils import dist_init, worker_name
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)


@torch.jit.script
def local_add(t1, t2):
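    # Scripted add; used as the RPC target below, so it runs on the callee worker.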
    return torch.add(t1, t2)


@torch.jit.script
def remote_add(t1, t2, dst: str):  # noqa: E999
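    # Issue an async RPC that runs local_add on the worker named `dst` and block on the result.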
    return rpc_async(dst, local_add, (t1, t2)).wait()


@torch.jit.script
def fork_add(t1, t2, dst: str):
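    # Fork remote_add onto a separate JIT thread, then wait for its result.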
    fut = torch.jit._fork(remote_add, t1, t2, dst)
    return torch.jit._wait(fut)


class JitDistAutogradTest(RpcAgentTestFixture):
    @dist_init
    def test_get_gradients(self):
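        """dist_autograd.get_gradients is scriptable and returns a gradient
        for every leaf tensor after a backward pass within the context."""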
        dst_rank = self.rank

        @torch.jit.script
        def dist_get_gradients(context_id: int) -> Dict[Tensor, Tensor]:
            return dist_autograd.get_gradients(context_id)

        FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            t3 = torch.add(t1, t2)

            dist_autograd.backward(context_id, [t3.sum()])
            grads = dist_get_gradients(context_id)

            self.assertEqual(2, len(grads))
            self.assertIn(t1, grads)
            self.assertIn(t2, grads)
            self.assertEqual(torch.ones(3, 3), grads[t1])
            self.assertEqual(torch.ones(3, 3), grads[t2])

    @dist_init
    def test_dist_backward(self):
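        """dist_autograd.backward is scriptable and can drive a distributed
        backward pass for a loss produced by an rpc_sync call."""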
        if self.rank != 0:
            return

        @torch.jit.script
        def dist_backward_script(context_id: int, loss: torch.Tensor):
            dist_autograd.backward(context_id, [loss])

        FileCheck().check("dist_backward").run(str(dist_backward_script.graph))
        with dist_autograd.context() as context_id:
            t1 = torch.rand(3, 3, requires_grad=True)
            t2 = torch.rand(3, 3, requires_grad=True)
            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
            loss = rpc.rpc_sync(dst_worker_name, torch.add, args=(t1, t2)).sum()
            dist_backward_script(context_id, loss)

    @dist_init
    def test_jit_fork_within_context(self):
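        """An RPC issued through torch.jit._fork/_wait inside a dist_autograd
        context still records gradients for the forked call's inputs."""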
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
            res = fork_add(t1, t2, dst_worker_name)
            loss = res.sum()
            dist_autograd.backward(context_id, [loss])

            grads = dist_autograd.get_gradients(context_id)
            self.assertEqual(2, len(grads))
            self.assertIn(t1, grads)
            self.assertIn(t2, grads)

    @dist_init
    def test_restore_context_after_switch_to_jit_thread(self):
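        """The distributed autograd context is restored on the JIT continuation
        thread that resumes a scripted function after waiting on an RPC future,
        so later RPCs still attach SendRpcBackward."""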
        if self.rank != 0:
            return

        @torch.jit.script
        def forward_script(
            context_id: int, dst_worker_name: str, t1: Tensor, t2: Tensor
        ) -> Tuple[Tensor, Tensor]:
            res1_fut = rpc.rpc_async(dst_worker_name, local_add, (t1, t1))
            res1 = res1_fut.wait()  # After this, the script runs in a new JIT thread.
            loss1 = res1.sum()

            # If the DistAutogradContext were not restored on the new JIT thread,
            # SendRpcBackward would not be attached to this send.
            res2_fut = rpc.rpc_async(dst_worker_name, local_add, (t2, t2))
            res2 = res2_fut.wait()
            loss2 = res2.sum()

            return loss1, loss2

        with dist_autograd.context() as context_id:
            t1 = torch.ones((2, 3), requires_grad=True)
            t2 = torch.ones((2, 3), requires_grad=True)
            dst_worker_name = worker_name((self.rank + 1) % self.world_size)
            loss0, loss1 = forward_script(context_id, dst_worker_name, t1, t2)
            dist_autograd.backward(context_id, [loss0, loss1])
            # Both sends attached SendRpcBackward, so both inputs receive
            # (identical) gradients.
            grads = dist_autograd.get_gradients(context_id)
            self.assertEqual(2, len(grads))
            self.assertEqual(grads[t1], grads[t2])