xref: /aosp_15_r20/external/pytorch/test/distributed/elastic/test_control_plane.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: distributed"]
3
import json
import os
import pickle
import socket
import tempfile
from contextlib import contextmanager
from typing import Dict, Iterator

from urllib3.connection import HTTPConnection
from urllib3.connectionpool import HTTPConnectionPool

from torch.distributed.elastic.control_plane import (
    TORCH_WORKER_SERVER_SOCKET,
    worker_main,
)
from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
20
21
class UnixHTTPConnection(HTTPConnection):
    """urllib3 HTTP connection that talks over a Unix domain socket.

    The host handed to the base class is a fixed ``"localhost"`` placeholder;
    the endpoint actually dialed is ``socket_path``.
    """

    def __init__(self, socket_path: str) -> None:
        """Remember the filesystem path of the Unix socket to connect to."""
        super().__init__("localhost")
        self.socket_path = socket_path

    def connect(self) -> None:
        """Open a stream socket to ``self.socket_path`` and store it in ``self.sock``."""
        unix_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        # Publish the socket on the connection before dialing, exactly as the
        # attribute is expected to be populated once connect() has been called.
        self.sock = unix_sock
        unix_sock.connect(self.socket_path)
31
32
class UnixHTTPConnectionPool(HTTPConnectionPool):
    """urllib3 connection pool whose connections use a Unix domain socket."""

    def __init__(self, socket_path: str) -> None:
        """Create a pool bound to the Unix socket at ``socket_path``."""
        # "localhost" is only a placeholder host for the base class; new
        # connections are built from socket_path in _new_conn().
        super().__init__("localhost")
        self.socket_path = socket_path

    def _new_conn(self) -> UnixHTTPConnection:
        """Return a fresh connection targeting this pool's socket path."""
        connection = UnixHTTPConnection(self.socket_path)
        return connection
41
42
@contextmanager
def local_worker_server() -> Iterator[UnixHTTPConnectionPool]:
    """Run a worker server on a temporary Unix socket and yield a client pool.

    A socket path inside a fresh temporary directory is exported through the
    ``TORCH_WORKER_SERVER_SOCKET`` environment variable before entering
    ``worker_main()`` (which is expected to serve on that socket -- see its
    implementation for details).

    Yields:
        A ``UnixHTTPConnectionPool`` connected to the temporary socket path.

    On exit the pool is closed, and the environment variable is restored to
    the value it had before the call (or removed if it was unset), so one
    test cannot leak a stale socket path into the next.
    """
    previous_socket = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
    with tempfile.TemporaryDirectory() as tmpdir:
        socket_path = os.path.join(tmpdir, "socket.sock")
        os.environ[TORCH_WORKER_SERVER_SOCKET] = socket_path
        try:
            with worker_main():
                pool = UnixHTTPConnectionPool(socket_path)
                try:
                    yield pool
                finally:
                    # Release client sockets even when the test body raises.
                    pool.close()
        finally:
            # Undo the environment mutation; the temp dir is about to vanish.
            if previous_socket is None:
                os.environ.pop(TORCH_WORKER_SERVER_SOCKET, None)
            else:
                os.environ[TORCH_WORKER_SERVER_SOCKET] = previous_socket
52
53
class WorkerServerTest(TestCase):
    """Tests for the worker control-plane HTTP server and its handlers."""

    def _check_nccl_trace_params(self, handler: str, multi_key_query: str) -> None:
        """POST a fixed series of query strings to ``handler`` and check statuses.

        Args:
            handler: Handler name placed after ``/handler/`` in the request path.
            multi_key_query: Query string with several valid keys; expected to
                be accepted with status 200.

        The requests are issued in order against a single local server and the
        first mismatching status fails the test.
        """
        cases = (
            # bad key - not lower case
            ("includeCollectives=true", 400),
            # unknown key
            ("unknownkey=true", 400),
            # bad value - not a bool
            ("includecollectives=notabool", 400),
            # bad value - value not lowercase
            ("includecollectives=True", 400),
            # good key and value
            ("includecollectives=true", 200),
            # multiple good keys and values
            (multi_key_query, 200),
        )
        with local_worker_server() as pool:
            for query, expected_status in cases:
                resp = pool.request("POST", f"/handler/{handler}?{query}")
                self.assertEqual(
                    resp.status,
                    expected_status,
                    msg=f"unexpected status for {handler}?{query}",
                )

    def test_worker_server(self) -> None:
        """Check the index page, ping handler, handler listing and 404 path."""
        with local_worker_server() as pool:
            resp = pool.request("GET", "/")
            self.assertEqual(resp.status, 200)
            self.assertEqual(
                resp.data,
                b"""<h1>torch.distributed.WorkerServer</h1>
<a href="/handler/">Handler names</a>
""",
            )

            resp = pool.request("POST", "/handler/ping")
            self.assertEqual(resp.status, 200)
            self.assertEqual(resp.data, b"pong")

            resp = pool.request("GET", "/handler/")
            self.assertEqual(resp.status, 200)
            self.assertIn("ping", json.loads(resp.data))

            resp = pool.request("POST", "/handler/nonexistant")
            self.assertEqual(resp.status, 404)
            self.assertIn(b"Handler nonexistant not found:", resp.data)

    @requires_cuda
    def test_dump_nccl_trace_pickle(self) -> None:
        """The pickle trace handler returns a pickled dict with a "version" key."""
        with local_worker_server() as pool:
            resp = pool.request("POST", "/handler/dump_nccl_trace_pickle")
            self.assertEqual(resp.status, 200)
            # The payload comes from the server started by this test itself,
            # so unpickling it here does not process untrusted data.
            out = pickle.loads(resp.data)
            self.assertIsInstance(out, dict)
            self.assertIn("version", out)

    @requires_cuda
    def test_dump_nccl_trace_pickle_with_params(self) -> None:
        """Query-parameter validation for the dump_nccl_trace_pickle handler."""
        self._check_nccl_trace_params(
            "dump_nccl_trace_pickle",
            "includecollectives=true&includestacktraces=false&onlyactive=true",
        )

    @requires_cuda
    def test_dump_nccl_trace_pickle_with_json(self) -> None:
        """Query-parameter validation for the dump_nccl_trace_json handler."""
        self._check_nccl_trace_params(
            "dump_nccl_trace_json",
            "includecollectives=true&onlyactive=true",
        )

    def test_tcp(self) -> None:
        """A worker server bound to a TCP port answers the handler listing."""
        import requests

        from torch._C._distributed_c10d import _WorkerServer

        server = _WorkerServer("", 1234)
        try:
            # A timeout keeps a wedged server from hanging the test run.
            out = requests.get("http://localhost:1234/handler/", timeout=30)
            self.assertEqual(out.status_code, 200)
        finally:
            # Always release the port, even when the request or assert fails.
            server.shutdown()

    def test_dump_traceback(self) -> None:
        """The traceback dump includes the frame of this test function."""
        with local_worker_server() as pool:
            resp = pool.request("POST", "/handler/dump_traceback")
            self.assertEqual(resp.status, 200)
            self.assertIn(b"in test_dump_traceback\n", resp.data)

    def test_run_handler(self) -> None:
        """Invoke the "ping" handler directly with Python request/response stubs."""
        from torch._C._distributed_c10d import _get_handler, _Request, _Response

        handler = _get_handler("ping")

        class Request(_Request):
            """Minimal request stub with a fixed body and no parameters."""

            def __init__(self) -> None:
                _Request.__init__(self)

            def body(self) -> bytes:
                return b"dummy"

            def params(self) -> Dict[str, str]:
                return {}

        class Response(_Response):
            """Response stub that records what the handler sets on it."""

            def __init__(self) -> None:
                _Response.__init__(self)

            def set_content(self, content: str, content_type: str) -> None:
                self.content = content
                self.content_type = content_type

            def set_status(self, status: int) -> None:
                self.status = status

        req = Request()
        resp = Response()

        handler(req, resp)

        self.assertEqual(resp.status, 200)
        self.assertEqual(resp.content, "pong")
        self.assertEqual(resp.content_type, "text/plain")

    def test_get_handler_nonexistant(self) -> None:
        """Looking up an unregistered handler raises ValueError."""
        from torch._C._distributed_c10d import _get_handler

        with self.assertRaisesRegex(ValueError, "Failed to find handler nonexistant"):
            _get_handler("nonexistant")

    def test_get_handler_names(self) -> None:
        """The registered handler names include "ping"."""
        from torch._C._distributed_c10d import _get_handler_names

        names = _get_handler_names()
        self.assertIn("ping", names)
218
219
# Run this module's tests through the shared PyTorch test runner when the
# file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
222