# mypy: allow-untyped-defs

# If you need to modify this file to make this test pass, please also apply the same edits to
# https://github.com/pytorch/examples/blob/master/distributed/rpc/batch/parameter_server.py
# and https://pytorch.org/tutorials/intermediate/rpc_async_execution.html#batch-updating-parameter-server

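"""End-to-end test of the batch-updating parameter server RPC example.

A single BatchUpdateParameterServer instance (owned by rank 0) holds an
nn.Linear model. Trainer ranks compute gradients locally and report them
through ``update_and_fetch_model``, an ``@rpc.functions.async_execution``
RPC whose future is only fulfilled after every trainer in the batch has
reported, so the optimizer takes exactly one step per round of updates.
"""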
import threading
from datetime import datetime
from time import perf_counter

import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
from torch import optim

from torch.testing._internal.dist_utils import (
    dist_init,
    worker_name,
)
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import RpcAgentTestFixture

batch_size = 20
in_features = 100
out_features = 30
num_batches = 4


def timed_log(text):
    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")


class BatchUpdateParameterServer:
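    """Accumulates gradients from all trainers and applies batched SGD steps.

    Reported gradients are summed into the model's ``.grad`` buffers; once
    ``batch_update_size`` updates have arrived they are averaged, one
    optimizer step runs, and the pending future is fulfilled with the
    updated model so every waiting trainer receives it at the same time.
    """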

    def __init__(self, batch_update_size):
        self.model = nn.Linear(in_features, out_features)
        self.lock = threading.Lock()
        self.future_model = torch.futures.Future()
        self.batch_update_size = batch_update_size
        self.curr_update_size = 0
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
        # Pre-allocate gradient buffers so trainer gradients can be
        # accumulated in place.
        for p in self.model.parameters():
            p.grad = torch.zeros_like(p)

    def get_model(self):
        return self.model

    @staticmethod
    @rpc.functions.async_execution
    def update_and_fetch_model(ps_rref, grads):
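        """Accumulate one trainer's ``grads``; return a future for the updated model.

        Because of ``@rpc.functions.async_execution`` the RPC only completes
        on the caller once the returned future is fulfilled, i.e. after the
        last trainer in the batch has reported and the optimizer has stepped.
        """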
        self = ps_rref.local_value()
        # Fold this trainer's gradients into the shared model's buffers.
        for p, g in zip(self.model.parameters(), grads):
            if p.grad is None:
                p.grad = g
            else:
                p.grad += g
        with self.lock:
            timed_log(f"PS got {self.curr_update_size}/{self.batch_update_size} updates")
            self.curr_update_size += 1
            fut = self.future_model

            if self.curr_update_size >= self.batch_update_size:
                # Last update of the batch: average gradients, take one SGD
                # step, and publish the new model to every waiting trainer.
                for p in self.model.parameters():
                    p.grad /= self.batch_update_size
                self.curr_update_size = 0
                self.optimizer.step()
                self.optimizer.zero_grad()
                fut.set_result(self.model)
                timed_log("PS updated model")
                self.future_model = torch.futures.Future()

        return fut


class Trainer:
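    """Runs a local training loop against a remote parameter server RRef.

    Each iteration computes gradients on synthetic data and trades them for
    a freshly updated model via ``update_and_fetch_model``.
    """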

    def __init__(self, ps_rref):
        self.ps_rref = ps_rref
        self.loss_fn = nn.L1Loss()

    def get_next_batch(self):
        # Synthetic data: random inputs paired with all-zero targets.
        for _ in range(num_batches):
            inputs = torch.randn(batch_size, in_features)
            labels = torch.zeros(batch_size, out_features)
            yield inputs, labels

    def train(self):
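        """Fetch the initial model, then alternate local backward passes with
        synchronous gradient reports to the parameter server; each report
        blocks until the whole batch has been applied and returns the
        updated model used for the next iteration.
        """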
        name = rpc.get_worker_info().name
        m = self.ps_rref.rpc_sync().get_model()
        for inputs, labels in self.get_next_batch():
            timed_log(f"{name} processing one batch")
            self.loss_fn(m(inputs), labels).backward()
            timed_log(f"{name} reporting grads")
            m = rpc.rpc_sync(
                self.ps_rref.owner(),
                BatchUpdateParameterServer.update_and_fetch_model,
                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),
            )
            timed_log(f"{name} got updated model")


def run_trainer(ps_rref):
    trainer = Trainer(ps_rref)
    trainer.train()


def run_ps(trainers):
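    """Create the parameter server locally and drive one training run.

    Each name in ``trainers`` is handed an RRef to the server and runs
    ``run_trainer`` asynchronously; this call returns once every trainer has
    finished its batches.
    """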
    timed_log("Start training")
    start = perf_counter()
    ps_rref = rpc.RRef(BatchUpdateParameterServer(len(trainers)))
    futs = []
    for trainer in trainers:
        futs.append(
            rpc.rpc_async(trainer, run_trainer, args=(ps_rref,))
        )

    torch.futures.wait_all(futs)
    stop = perf_counter()
    timed_log("Finish training")
    timed_log(f"Time spent training: {stop-start}s")


class ParameterServerTest(RpcAgentTestFixture):
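    """Runs the example end to end: rank 0 hosts the parameter server, every
    other rank acts as a trainer.
    """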

    @dist_init(setup_rpc=False)
    def test_batch_updating_parameter_server(self):
        # Every rank joins the RPC group; rank 0 additionally creates the
        # parameter server and drives training on the remaining ranks.
        rpc.init_rpc(
            name=worker_name(self.rank),
            backend=self.rpc_backend,
            rank=self.rank,
            world_size=self.world_size,
            rpc_backend_options=self.rpc_backend_options,
        )
        if self.rank == 0:
            run_ps([worker_name(r) for r in range(1, self.world_size)])

        rpc.shutdown()