#!/usr/bin/env python3
# Owner(s): ["oncall: distributed"]

import json
import os
import pickle
import socket
import tempfile
from contextlib import contextmanager
from typing import Dict, Iterator

from urllib3.connection import HTTPConnection
from urllib3.connectionpool import HTTPConnectionPool

from torch.distributed.elastic.control_plane import (
    TORCH_WORKER_SERVER_SOCKET,
    worker_main,
)
from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase


class UnixHTTPConnection(HTTPConnection):
    """HTTP connection that connects to a Unix domain socket path."""

    def __init__(self, socket_path: str) -> None:
        # The host name is a placeholder; ``connect`` below uses ``socket_path``.
        super().__init__("localhost")

        self.socket_path = socket_path

    def connect(self) -> None:
        """Open an ``AF_UNIX`` stream socket to ``self.socket_path``.

        Raises:
            OSError: if the connection fails. The socket is closed before
                the error propagates so the file descriptor is not leaked.
        """
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            sock.connect(self.socket_path)
        except OSError:
            sock.close()
            raise
        self.sock = sock


class UnixHTTPConnectionPool(HTTPConnectionPool):
    """Connection pool whose connections are ``UnixHTTPConnection`` objects."""

    def __init__(self, socket_path: str) -> None:
        # The host name is a placeholder; connections use ``socket_path``.
        super().__init__("localhost")

        self.socket_path = socket_path

    def _new_conn(self):
        """Return a new connection bound to this pool's socket path."""
        return UnixHTTPConnection(self.socket_path)


@contextmanager
def local_worker_server() -> Iterator[UnixHTTPConnectionPool]:
    """Run ``worker_main`` on a temporary Unix socket and yield a client pool.

    The socket path is published through the ``TORCH_WORKER_SERVER_SOCKET``
    environment variable before ``worker_main`` is entered. On exit the pool
    is closed and the environment variable is restored to its previous state,
    so it never keeps pointing into the deleted temporary directory.

    Yields:
        A ``UnixHTTPConnectionPool`` connected to the temporary socket path.
    """
    previous = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
    with tempfile.TemporaryDirectory() as tmpdir:
        socket_path = os.path.join(tmpdir, "socket.sock")
        os.environ[TORCH_WORKER_SERVER_SOCKET] = socket_path
        try:
            with worker_main():
                pool = UnixHTTPConnectionPool(socket_path)
                try:
                    yield pool
                finally:
                    # Drop client connections before the server context exits.
                    pool.close()
        finally:
            if previous is None:
                os.environ.pop(TORCH_WORKER_SERVER_SOCKET, None)
            else:
                os.environ[TORCH_WORKER_SERVER_SOCKET] = previous


class WorkerServerTest(TestCase):
    """Tests for the worker server HTTP endpoints and handler bindings."""

    def test_worker_server(self) -> None:
        """Check the index page, ``ping``, handler listing and the 404 path."""
        with local_worker_server() as pool:
            resp = pool.request("GET", "/")
            self.assertEqual(resp.status, 200)
            self.assertEqual(
                resp.data,
                b"""<h1>torch.distributed.WorkerServer</h1>
<a href="/handler/">Handler names</a>
""",
            )

            resp = pool.request("POST", "/handler/ping")
            self.assertEqual(resp.status, 200)
            self.assertEqual(resp.data, b"pong")

            resp = pool.request("GET", "/handler/")
            self.assertEqual(resp.status, 200)
            self.assertIn("ping", json.loads(resp.data))

            resp = pool.request("POST", "/handler/nonexistant")
            self.assertEqual(resp.status, 404)
            self.assertIn(b"Handler nonexistant not found:", resp.data)

    @requires_cuda
    def test_dump_nccl_trace_pickle(self) -> None:
        """The pickle dump handler returns a pickled dict with a ``version`` key."""
        with local_worker_server() as pool:
            resp = pool.request("POST", "/handler/dump_nccl_trace_pickle")
            self.assertEqual(resp.status, 200)
            # NOTE: the payload comes from the local worker server started for
            # this test; pickle.loads must never be used on untrusted data.
            out = pickle.loads(resp.data)
            self.assertIsInstance(out, dict)
            self.assertIn("version", out)

    @requires_cuda
    def test_dump_nccl_trace_pickle_with_params(self) -> None:
        """Query-parameter validation for ``dump_nccl_trace_pickle``."""
        with local_worker_server() as pool:
            # bad key - not lower case
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_pickle?includeCollectives=true"
            )
            self.assertEqual(resp.status, 400)
            # unknown key
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_pickle?unknownkey=true"
            )
            self.assertEqual(resp.status, 400)
            # bad value - not a bool
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_pickle?includecollectives=notabool"
            )
            self.assertEqual(resp.status, 400)
            # bad value - value not lowercase
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_pickle?includecollectives=True"
            )
            self.assertEqual(resp.status, 400)
            # good key and value
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_pickle?includecollectives=true"
            )
            self.assertEqual(resp.status, 200)
            # multiple good keys and values
            resp = pool.request(
                "POST",
                "/handler/dump_nccl_trace_pickle?includecollectives=true&includestacktraces=false&onlyactive=true",
            )
            self.assertEqual(resp.status, 200)

    @requires_cuda
    def test_dump_nccl_trace_pickle_with_json(self) -> None:
        """Query-parameter validation for ``dump_nccl_trace_json``."""
        with local_worker_server() as pool:
            # bad key - not lower case
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_json?includeCollectives=true"
            )
            self.assertEqual(resp.status, 400)
            # unknown key
            resp = pool.request("POST", "/handler/dump_nccl_trace_json?unknownkey=true")
            self.assertEqual(resp.status, 400)
            # bad value - not a bool
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_json?includecollectives=notabool"
            )
            self.assertEqual(resp.status, 400)
            # bad value - value not lowercase
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_json?includecollectives=True"
            )
            self.assertEqual(resp.status, 400)
            # good key and value
            resp = pool.request(
                "POST", "/handler/dump_nccl_trace_json?includecollectives=true"
            )
            self.assertEqual(resp.status, 200)
            # multiple good keys and values
            resp = pool.request(
                "POST",
                "/handler/dump_nccl_trace_json?includecollectives=true&onlyactive=true",
            )
            self.assertEqual(resp.status, 200)

    def test_tcp(self) -> None:
        """A ``_WorkerServer`` on a TCP port answers the handler listing."""
        import requests

        from torch._C._distributed_c10d import _WorkerServer

        server = _WorkerServer("", 1234)
        try:
            out = requests.get("http://localhost:1234/handler/")
            self.assertEqual(out.status_code, 200)
        finally:
            # Always release the port, even when the request or assertion fails.
            server.shutdown()

    def test_dump_traceback(self) -> None:
        """The traceback dump includes this test's own frame."""
        with local_worker_server() as pool:
            resp = pool.request("POST", "/handler/dump_traceback")
            self.assertEqual(resp.status, 200)
            self.assertIn(b"in test_dump_traceback\n", resp.data)

    def test_run_handler(self) -> None:
        """Call the ``ping`` handler directly with Python request/response objects."""
        from torch._C._distributed_c10d import _get_handler, _Request, _Response

        handler = _get_handler("ping")

        class Request(_Request):
            def __init__(self) -> None:
                _Request.__init__(self)

            def body(self) -> bytes:
                return b"dummy"

            def params(self) -> Dict[str, str]:
                return {}

        class Response(_Response):
            def __init__(self) -> None:
                _Response.__init__(self)

            def set_content(self, content: str, content_type: str) -> None:
                # Record what the handler wrote so the test can inspect it.
                self.content = content
                self.content_type = content_type

            def set_status(self, status: int) -> None:
                self.status = status

        req = Request()
        resp = Response()

        handler(req, resp)

        self.assertEqual(resp.status, 200)
        self.assertEqual(resp.content, "pong")
        self.assertEqual(resp.content_type, "text/plain")

    def test_get_handler_nonexistant(self) -> None:
        """Looking up an unknown handler name raises ``ValueError``."""
        from torch._C._distributed_c10d import _get_handler

        with self.assertRaisesRegex(ValueError, "Failed to find handler nonexistant"):
            _get_handler("nonexistant")

    def test_get_handler_names(self) -> None:
        """``ping`` is among the registered handler names."""
        from torch._C._distributed_c10d import _get_handler_names

        names = _get_handler_names()
        self.assertIn("ping", names)


if __name__ == "__main__":
    run_tests()