# mypy: allow-untyped-defs


import threading

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch import optim
from torch.distributed.optim import DistributedOptimizer
from torch.testing._internal.dist_utils import dist_init
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)


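# A minimal module holding a single 3x3 weight. It is a plain Python class
# rather than an nn.Module, and serves as the target of the remote RPC calls
# in the tests below.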
class MyModule:
    lock = threading.Lock()

    def __init__(self, requires_grad=True):
        # cannot directly use torch.manual_seed(0) as all threads share the same
        # default generator. The race from multiple RPC threads could mess up
        # the draw order from the default RNG instance, leading to
        # non-deterministic behavior. Hence, create a dedicated RNG here.
        g_cpu = torch.Generator()
        g_cpu.manual_seed(0)
        self.w = torch.rand((3, 3), requires_grad=requires_grad, generator=g_cpu)

    def forward(self, t1):
        return torch.mm(self.w, t1)

    def get_w(self):
        return self.w


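# Optimizer stub whose step() always raises, used to check that errors raised
# inside DistributedOptimizer.step() propagate back to the caller.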
class FailingOptimizer(optim.Optimizer):
    def __init__(self, params):
        super().__init__(params, {})

    def step(self, closure=None):
        raise ValueError("Error running optimizer.")


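# Optimizer stub that raises in its constructor, used to check that errors
# raised while building the remote optimizers propagate out of the
# DistributedOptimizer constructor.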
class OptimizerFailingOnConstructor(optim.Optimizer):
    def __init__(self, params):
        super().__init__(params, {})
        raise ValueError("Error creating optimizer.")

    def step(self, closure=None):
        raise NotImplementedError


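# Runs on the owner of obj_rref: unwraps the RRef into the local object and
# invokes the given unbound method on it.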
def _call_method(method, obj_rref, *args, **kwargs):
    return method(obj_rref.local_value(), *args, **kwargs)


def remote_method(method, obj_rref, *args, **kwargs):
    """
    Call rpc.remote on a method of a remote object.

    Args:
        method: the method (for example, Class.method)
        obj_rref (RRef): remote reference to the object
        args: positional arguments to pass to the method
        kwargs: keyword arguments to pass to the method

    Returns an RRef to the remote method call result.
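
    Example (illustrative; assumes ``module_rref`` is an RRef to a MyModule)::

        w_rref = remote_method(MyModule.get_w, module_rref)
        w = w_rref.to_here()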
68    """
69    return rpc.remote(
70        obj_rref.owner(),
71        _call_method,
72        args=[method, obj_rref] + list(args),
73        kwargs=kwargs,
74    )
75
76
def rpc_async_method(method, obj_rref, *args, **kwargs):
    """
    Call rpc.rpc_async on a method of a remote object.

    Args:
        method: the method (for example, Class.method)
        obj_rref (RRef): remote reference to the object
        args: positional arguments to pass to the method
        kwargs: keyword arguments to pass to the method

    Returns a Future to the method call result.
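
    Example (illustrative; assumes ``module_rref`` is an RRef to a MyModule)::

        fut = rpc_async_method(MyModule.forward, module_rref, t)
        out = fut.wait()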
88    """
89    return rpc.rpc_async(
90        obj_rref.owner(),
91        _call_method,
92        args=[method, obj_rref] + list(args),
93        kwargs=kwargs,
94    )
95
96
class DistOptimizerTest(RpcAgentTestFixture):
    @dist_init()
    def test_dist_optim_exception(self):
        # set up remote modules and their parameters on two other workers
        owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
        owner2 = "worker%d" % ((self.rank + 2) % self.world_size)

        remote_module1 = rpc.remote(owner1, MyModule)
        remote_module2 = rpc.remote(owner2, MyModule)
        remote_param1 = remote_method(MyModule.get_w, remote_module1)
        remote_param2 = remote_method(MyModule.get_w, remote_module2)

        dist_optim = DistributedOptimizer(
            FailingOptimizer, [remote_param1, remote_param2]
        )

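        # run a forward/backward across both remote modules; the subsequent
        # optimizer step should surface the error raised by FailingOptimizer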
        with dist_autograd.context() as context_id:
            g_cpu = torch.Generator()
            g_cpu.manual_seed(0)
            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
            output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
            loss = torch.add(output2.wait(), t1).sum()

            dist_autograd.backward(context_id, [loss])
            with self.assertRaisesRegex(Exception, "Error running optimizer"):
                dist_optim.step(context_id)

    @dist_init()
    def test_dist_optim_exception_on_constructor(self):
        # set up remote modules and their parameters on two other workers
        owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
        owner2 = "worker%d" % ((self.rank + 2) % self.world_size)

        remote_module1 = rpc.remote(owner1, MyModule)
        remote_module2 = rpc.remote(owner2, MyModule)
        remote_param1 = remote_method(MyModule.get_w, remote_module1)
        remote_param2 = remote_method(MyModule.get_w, remote_module2)

        with self.assertRaisesRegex(Exception, "Error creating optimizer."):
            DistributedOptimizer(
                OptimizerFailingOnConstructor, [remote_param1, remote_param2]
            )

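    # Runs the given optimizer both locally and through DistributedOptimizer on
    # identical seeded inputs, then checks that the resulting weights match.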
    def _test_dist_optim_base(self, optim_cls, *args, **kwargs):
        # local version
        module1 = MyModule()
        module2 = MyModule()
        params = [module1.get_w(), module2.get_w()]
        local_optim = optim_cls(params, *args, **kwargs)

        old_w1 = module1.w.clone().detach()
        old_w2 = module2.w.clone().detach()

        g_cpu = torch.Generator()
        g_cpu.manual_seed(0)
        t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
        t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
        output1 = module1.forward(t2)
        output2 = module2.forward(output1)
        loss = torch.add(output2, t1).sum()

        loss.backward()
        local_optim.step()

        # distributed version
        owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
        owner2 = "worker%d" % ((self.rank + 2) % self.world_size)

        remote_module1 = rpc.remote(owner1, MyModule)
        remote_module2 = rpc.remote(owner2, MyModule)
        remote_param1 = remote_method(MyModule.get_w, remote_module1)
        remote_param2 = remote_method(MyModule.get_w, remote_module2)

        # sanity check: local and remote initial weights should match
        self.assertEqual(old_w1, remote_param1.to_here())
        self.assertEqual(old_w2, remote_param2.to_here())

        dist_optim = DistributedOptimizer(
            optim_cls, [remote_param1, remote_param2], *args, **kwargs
        )

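        # repeat the computation through RPC; re-seeding the generator makes
        # t1/t2 identical to the inputs used in the local run above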
        with dist_autograd.context() as context_id:
            g_cpu.manual_seed(0)
            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
            output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
            loss = torch.add(output2.wait(), t1)

            dist_autograd.backward(context_id, [loss.sum()])
            dist_optim.step(context_id)

            new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait()
            new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait()

            # ensure optimizer changed weights
            self.assertNotEqual(old_w1, new_w1)
            self.assertNotEqual(old_w2, new_w2)
            # ensure local equals remote
            self.assertEqual(new_w1, module1.get_w())
            self.assertEqual(new_w2, module2.get_w())

    @dist_init()
    def test_dist_optim(self):
        self._test_dist_optim_base(optim.Adagrad, lr=0.05)
        self._test_dist_optim_base(optim.Adam, lr=1e-2, amsgrad=True)
        self._test_dist_optim_base(optim.AdamW, lr=0.05, amsgrad=True)
        self._test_dist_optim_base(optim.SGD, lr=0.05)
        self._test_dist_optim_base(
            optim.SGD, lr=1e-3, momentum=1, weight_decay=1, nesterov=True
        )
        self._test_dist_optim_base(optim.Adadelta, rho=0.95)
        self._test_dist_optim_base(optim.RMSprop, lr=0.05)
        self._test_dist_optim_base(optim.Adamax, lr=0.05)
        self._test_dist_optim_base(optim.Rprop, lr=0.05)

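    # Same structure as _test_dist_optim_base, but module2's weight has
    # requires_grad=False, so its gradient is None and the optimizer should
    # leave it unchanged.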
    def _test_dist_optim_none_grads(self, optim_cls, *args, **kwargs):
        # local version
        module1 = MyModule()
        module2 = MyModule(requires_grad=False)
        params = [module1.get_w(), module2.get_w()]
        local_optim = optim_cls(params, *args, **kwargs)

        old_w1 = module1.w.clone().detach()
        old_w2 = module2.w.clone().detach()

        g_cpu = torch.Generator()
        g_cpu.manual_seed(0)
        t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
        t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
        output1 = module1.forward(t2)
        output2 = module2.forward(output1)
        loss = torch.add(output2, t1).sum()

        loss.backward()
        local_optim.step()

        # distributed version
        owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
        owner2 = "worker%d" % ((self.rank + 2) % self.world_size)

        remote_module1 = rpc.remote(owner1, MyModule)
        remote_module2 = rpc.remote(owner2, MyModule, args=(False,))
        remote_param1 = remote_module1.remote().get_w()
        remote_param2 = remote_module2.remote().get_w()

        # sanity check: local and remote initial weights should match
        self.assertEqual(old_w1, remote_param1.to_here())
        self.assertEqual(old_w2, remote_param2.to_here())

        dist_optim = DistributedOptimizer(
            optim_cls, [remote_param1, remote_param2], *args, **kwargs
        )

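        # repeat the computation through RPC with the same seeded inputs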
        with dist_autograd.context() as context_id:
            g_cpu.manual_seed(0)
            t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu)
            output1 = remote_module1.rpc_async().forward(t2)
            output2 = remote_module2.rpc_async().forward(output1.wait())
            loss = torch.add(output2.wait(), t1)

            dist_autograd.backward(context_id, [loss.sum()])
            dist_optim.step(context_id)

            new_w1 = remote_module1.rpc_async().get_w().wait()
            new_w2 = remote_module2.rpc_async().get_w().wait()

            # ensure optimizer changed weights for w1
            self.assertNotEqual(old_w1, new_w1)

            # ensure optimizer did not change weights for w2
            self.assertEqual(old_w2, new_w2)
            # ensure local equals remote
            self.assertEqual(new_w1, module1.get_w())
            self.assertEqual(new_w2, module2.get_w())

    @dist_init()
    def test_dist_optim_none_grads(self):
        self._test_dist_optim_none_grads(optim.SGD, lr=0.05)
        self._test_dist_optim_none_grads(optim.RMSprop, lr=0.05)
        self._test_dist_optim_none_grads(optim.Rprop, lr=0.05)
        self._test_dist_optim_none_grads(optim.Adadelta, rho=0.95)