# mypy: allow-untyped-defs

# If you need to modify this file to make this test pass, please also apply same edits accordingly to
# https://github.com/pytorch/examples/blob/master/distributed/rpc/batch/parameter_server.py
# and https://pytorch.org/tutorials/intermediate/rpc_async_execution.html#batch-updating-parameter-server

import threading
from datetime import datetime
from time import perf_counter

import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
from torch import optim

from torch.testing._internal.dist_utils import (
    dist_init,
    worker_name,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import RpcAgentTestFixture

batch_size = 20
in_features = 100
out_features = 30
num_batches = 4


def timed_log(text):
    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")


class BatchUpdateParameterServer:

    def __init__(self, batch_update_size):
        self.model = nn.Linear(in_features, out_features)
        self.lock = threading.Lock()
        self.future_model = torch.futures.Future()
        self.batch_update_size = batch_update_size
        self.curr_update_size = 0
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        # Pre-allocate .grad buffers so gradients reported by trainers can be
        # accumulated in place.
        for p in self.model.parameters():
            p.grad = torch.zeros_like(p)

    def get_model(self):
        return self.model

    # With @rpc.functions.async_execution this method returns a Future instead of
    # blocking the RPC thread: each trainer's gradients are accumulated, and the
    # Future is only fulfilled (with the refreshed model) once all trainers have
    # reported, so the optimizer steps once per batch of updates.
    @staticmethod
    @rpc.functions.async_execution
    def update_and_fetch_model(ps_rref, grads):
        self = ps_rref.local_value()
        for p, g in zip(self.model.parameters(), grads):
            if p.grad is None:
                p.grad = g
            else:
                p.grad += g
        with self.lock:
            timed_log(f"PS got {self.curr_update_size}/{self.batch_update_size} updates")
            self.curr_update_size += 1
            fut = self.future_model

            if self.curr_update_size >= self.batch_update_size:
                for p in self.model.parameters():
                    p.grad /= self.batch_update_size
                self.curr_update_size = 0
                self.optimizer.step()
                self.optimizer.zero_grad()
                fut.set_result(self.model)
                timed_log("PS updated model")
                self.future_model = torch.futures.Future()

        return fut


class Trainer:

    def __init__(self, ps_rref):
        self.ps_rref = ps_rref
        self.loss_fn = nn.L1Loss()

    def get_next_batch(self):
        for _ in range(num_batches):
            inputs = torch.randn(batch_size, in_features)
            labels = torch.zeros(batch_size, out_features)
            yield inputs, labels

    def train(self):
        name = rpc.get_worker_info().name
        m = self.ps_rref.rpc_sync().get_model()
        # Each iteration computes local gradients, ships them to the parameter
        # server, and blocks until the server replies with the updated model.
        for inputs, labels in self.get_next_batch():
            timed_log(f"{name} processing one batch")
            self.loss_fn(m(inputs), labels).backward()
            timed_log(f"{name} reporting grads")
            m = rpc.rpc_sync(
                self.ps_rref.owner(),
                BatchUpdateParameterServer.update_and_fetch_model,
                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),
            )
            timed_log(f"{name} got updated model")


def run_trainer(ps_rref):
    trainer = Trainer(ps_rref)
    trainer.train()


def run_ps(trainers):
    timed_log("Start training")
    start = perf_counter()
    ps_rref = rpc.RRef(BatchUpdateParameterServer(len(trainers)))
    # Launch one asynchronous training loop per trainer worker and wait for all
    # of them to finish.
    futs = []
    for trainer in trainers:
        futs.append(
            rpc.rpc_async(trainer, run_trainer, args=(ps_rref,))
        )

    torch.futures.wait_all(futs)
    stop = perf_counter()
    timed_log("Finish training")
    timed_log(f"Time spent training: {stop-start}s")
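

# A minimal sketch (kept as comments) of how the linked tutorial launches the same
# example outside this test fixture; the worker names, backend options, and world
# size below are illustrative assumptions, not part of this test:
#
#     import torch.multiprocessing as mp
#
#     def run(rank, world_size):
#         options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16)
#         if rank == 0:
#             rpc.init_rpc("ps", rank=rank, world_size=world_size,
#                          rpc_backend_options=options)
#             run_ps([f"trainer{r}" for r in range(1, world_size)])
#         else:
#             rpc.init_rpc(f"trainer{rank}", rank=rank, world_size=world_size,
#                          rpc_backend_options=options)
#         rpc.shutdown()
#
#     # mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)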


class ParameterServerTest(RpcAgentTestFixture):

    @dist_init(setup_rpc=False)
    def test_batch_updating_parameter_server(self):
        # Every rank joins the RPC group; rank 0 additionally owns the parameter
        # server and drives the remaining ranks as trainers via run_ps.
        if self.rank != 0:
            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
        else:
            rpc.init_rpc(
                name=worker_name(self.rank),
                backend=self.rpc_backend,
                rank=self.rank,
                world_size=self.world_size,
                rpc_backend_options=self.rpc_backend_options,
            )
            run_ps([f"{worker_name(r)}" for r in range(1, self.world_size)])

        rpc.shutdown()