1# mypy: allow-untyped-defs 2 3 4import threading 5 6import torch 7import torch.distributed.autograd as dist_autograd 8import torch.distributed.rpc as rpc 9from torch import optim 10from torch.distributed.optim import DistributedOptimizer 11from torch.testing._internal.dist_utils import dist_init 12from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( 13 RpcAgentTestFixture, 14) 15 16 17class MyModule: 18 lock = threading.Lock() 19 20 def __init__(self, requires_grad=True): 21 # cannot directly use torch.manual_seed(0) as all threads share the same 22 # default generator. The race from multiple RPC threads could mess up 23 # the draw order from the default RNG instance, leading to 24 # non-deterministic behavior. Hence, create a dedicated RNG here. 25 g_cpu = torch.Generator() 26 g_cpu.manual_seed(0) 27 self.w = torch.rand((3, 3), requires_grad=requires_grad, generator=g_cpu) 28 29 def forward(self, t1): 30 return torch.mm(self.w, t1) 31 32 def get_w(self): 33 return self.w 34 35 36class FailingOptimizer(optim.Optimizer): 37 def __init__(self, params): 38 super().__init__(params, {}) 39 40 def step(self, closure=None): 41 raise ValueError("Error running optimizer.") 42 43 44class OptimizerFailingOnConstructor(optim.Optimizer): 45 def __init__(self, params): 46 super().__init__(params, {}) 47 raise ValueError("Error creating optimizer.") 48 49 def step(self, closure=None): 50 raise NotImplementedError 51 52 53def _call_method(method, obj_rref, *args, **kwargs): 54 return method(obj_rref.local_value(), *args, **kwargs) 55 56 57def remote_method(method, obj_rref, *args, **kwargs): 58 """ 59 Call rpc.remote on a method in a remote object. 60 61 Args: 62 method: the method (for example, Class.method) 63 obj_rref (RRef): remote reference to the object 64 args: positional arguments to pass to the method 65 kwargs: keyword arguments to pass to the method 66 67 Returns a RRef to the remote method call result. 68 """ 69 return rpc.remote( 70 obj_rref.owner(), 71 _call_method, 72 args=[method, obj_rref] + list(args), 73 kwargs=kwargs, 74 ) 75 76 77def rpc_async_method(method, obj_rref, *args, **kwargs): 78 """ 79 Call rpc.rpc_async on a method in a remote object. 80 81 Args: 82 method: the method (for example, Class.method) 83 obj_rref (RRef): remote reference to the object 84 args: positional arguments to pass to the method 85 kwargs: keyword arguments to pass to the method 86 87 Returns a Future to the method call result. 88 """ 89 return rpc.rpc_async( 90 obj_rref.owner(), 91 _call_method, 92 args=[method, obj_rref] + list(args), 93 kwargs=kwargs, 94 ) 95 96 97class DistOptimizerTest(RpcAgentTestFixture): 98 @dist_init() 99 def test_dist_optim_exception(self): 100 # distributed version 101 owner1 = "worker%d" % ((self.rank + 1) % self.world_size) 102 owner2 = "worker%d" % ((self.rank + 2) % self.world_size) 103 104 remote_module1 = rpc.remote(owner1, MyModule) 105 remote_module2 = rpc.remote(owner2, MyModule) 106 remote_param1 = remote_method(MyModule.get_w, remote_module1) 107 remote_param2 = remote_method(MyModule.get_w, remote_module2) 108 109 dist_optim = DistributedOptimizer( 110 FailingOptimizer, [remote_param1, remote_param2] 111 ) 112 113 with dist_autograd.context() as context_id: 114 g_cpu = torch.Generator() 115 g_cpu.manual_seed(0) 116 t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) 117 t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) 118 output1 = rpc_async_method(MyModule.forward, remote_module1, t2) 119 output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait()) 120 loss = torch.add(output2.wait(), t1).sum() 121 122 dist_autograd.backward(context_id, [loss]) 123 with self.assertRaisesRegex(Exception, "Error running optimizer"): 124 dist_optim.step(context_id) 125 126 @dist_init() 127 def test_dist_optim_exception_on_constructor(self): 128 # distributed version 129 owner1 = "worker%d" % ((self.rank + 1) % self.world_size) 130 owner2 = "worker%d" % ((self.rank + 2) % self.world_size) 131 132 remote_module1 = rpc.remote(owner1, MyModule) 133 remote_module2 = rpc.remote(owner2, MyModule) 134 remote_param1 = remote_method(MyModule.get_w, remote_module1) 135 remote_param2 = remote_method(MyModule.get_w, remote_module2) 136 137 with self.assertRaisesRegex(Exception, "Error creating optimizer."): 138 dist_optim = DistributedOptimizer( 139 OptimizerFailingOnConstructor, [remote_param1, remote_param2] 140 ) 141 142 def _test_dist_optim_base(self, optim_cls, *args, **kwargs): 143 # local version 144 module1 = MyModule() 145 module2 = MyModule() 146 params = [module1.get_w(), module2.get_w()] 147 local_optim = optim_cls(params, *args, **kwargs) 148 149 old_w1 = module1.w.clone().detach() 150 old_w2 = module2.w.clone().detach() 151 152 g_cpu = torch.Generator() 153 g_cpu.manual_seed(0) 154 t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) 155 t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) 156 output1 = module1.forward(t2) 157 output2 = module2.forward(output1) 158 loss = torch.add(output2, t1).sum() 159 160 loss.backward() 161 local_optim.step() 162 163 # distributed version 164 owner1 = "worker%d" % ((self.rank + 1) % self.world_size) 165 owner2 = "worker%d" % ((self.rank + 2) % self.world_size) 166 167 remote_module1 = rpc.remote(owner1, MyModule) 168 remote_module2 = rpc.remote(owner2, MyModule) 169 remote_param1 = remote_method(MyModule.get_w, remote_module1) 170 remote_param2 = remote_method(MyModule.get_w, remote_module2) 171 172 old_w1_remote = remote_param1.to_here() 173 174 # sanity check: local and remote initial weights should match 175 self.assertEqual(old_w1, remote_param1.to_here()) 176 self.assertEqual(old_w2, remote_param2.to_here()) 177 178 dist_optim = DistributedOptimizer( 179 optim_cls, [remote_param1, remote_param2], *args, **kwargs 180 ) 181 182 with dist_autograd.context() as context_id: 183 g_cpu.manual_seed(0) 184 t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) 185 t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) 186 output1 = rpc_async_method(MyModule.forward, remote_module1, t2) 187 output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait()) 188 loss = torch.add(output2.wait(), t1) 189 190 dist_autograd.backward(context_id, [loss.sum()]) 191 dist_optim.step(context_id) 192 193 new_w1 = rpc_async_method(MyModule.get_w, remote_module1).wait() 194 new_w2 = rpc_async_method(MyModule.get_w, remote_module2).wait() 195 196 # ensure optimizer changed weights 197 self.assertNotEqual(old_w1, new_w1) 198 self.assertNotEqual(old_w2, new_w2) 199 # ensure local equals remote 200 self.assertEqual(new_w1, module1.get_w()) 201 self.assertEqual(new_w2, module2.get_w()) 202 203 @dist_init() 204 def test_dist_optim(self): 205 self._test_dist_optim_base(optim.Adagrad, lr=0.05) 206 self._test_dist_optim_base(optim.Adam, lr=1e-2, amsgrad=True) 207 self._test_dist_optim_base(optim.AdamW, lr=0.05, amsgrad=True) 208 self._test_dist_optim_base(optim.SGD, lr=0.05) 209 self._test_dist_optim_base(optim.SGD, lr=1e-3, momentum=1, weight_decay=1, nesterov=True) 210 self._test_dist_optim_base(optim.Adadelta, rho=0.95) 211 self._test_dist_optim_base(optim.RMSprop, lr=0.05) 212 self._test_dist_optim_base(optim.Adamax, lr=0.05) 213 self._test_dist_optim_base(optim.Rprop, lr=0.05) 214 215 def _test_dist_optim_none_grads(self, optim_cls, *args, **kwargs): 216 # local version 217 module1 = MyModule() 218 module2 = MyModule(requires_grad=False) 219 params = [module1.get_w(), module2.get_w()] 220 local_optim = optim_cls(params, *args, **kwargs) 221 222 old_w1 = module1.w.clone().detach() 223 old_w2 = module2.w.clone().detach() 224 225 g_cpu = torch.Generator() 226 g_cpu.manual_seed(0) 227 t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) 228 t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) 229 output1 = module1.forward(t2) 230 output2 = module2.forward(output1) 231 loss = torch.add(output2, t1).sum() 232 233 loss.backward() 234 local_optim.step() 235 236 # distributed version 237 owner1 = "worker%d" % ((self.rank + 1) % self.world_size) 238 owner2 = "worker%d" % ((self.rank + 2) % self.world_size) 239 240 remote_module1 = rpc.remote(owner1, MyModule) 241 remote_module2 = rpc.remote(owner2, MyModule, args=(False,)) 242 remote_param1 = remote_module1.remote().get_w() 243 remote_param2 = remote_module2.remote().get_w() 244 245 # sanity check: local and remote initial weights should match 246 self.assertEqual(old_w1, remote_param1.to_here()) 247 self.assertEqual(old_w2, remote_param2.to_here()) 248 249 dist_optim = DistributedOptimizer( 250 optim_cls, [remote_param1, remote_param2], *args, **kwargs 251 ) 252 253 with dist_autograd.context() as context_id: 254 g_cpu.manual_seed(0) 255 t1 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) 256 t2 = torch.rand((3, 3), requires_grad=True, generator=g_cpu) 257 output1 = remote_module1.rpc_async().forward(t2) 258 output2 = remote_module2.rpc_async().forward(output1.wait()) 259 loss = torch.add(output2.wait(), t1) 260 261 dist_autograd.backward(context_id, [loss.sum()]) 262 dist_optim.step(context_id) 263 264 new_w1 = remote_module1.rpc_async().get_w().wait() 265 new_w2 = remote_module2.rpc_async().get_w().wait() 266 267 # ensure optimizer changed weights for w1 268 self.assertNotEqual(old_w1, new_w1) 269 270 # ensure optimizer not changed weights for w2 271 self.assertEqual(old_w2, new_w2) 272 # ensure local equals remote 273 self.assertEqual(new_w1, module1.get_w()) 274 self.assertEqual(new_w2, module2.get_w()) 275 276 @dist_init() 277 def test_dist_optim_none_grads(self): 278 self._test_dist_optim_none_grads(optim.SGD, lr=0.05) 279 self._test_dist_optim_none_grads(optim.RMSprop, lr=0.05) 280 self._test_dist_optim_none_grads(optim.Rprop, lr=0.05) 281 self._test_dist_optim_none_grads(optim.Adadelta, rho=0.95) 282