#include <torch/csrc/distributed/rpc/request_callback_no_python.h>

#include <c10/core/StreamGuard.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/utils.h>

#include <utility>

namespace torch::distributed::rpc {

using namespace torch::distributed::autograd;
using namespace torch::autograd::profiler;

// When a request message has autograd info, processMessage() will properly
// set up a valid current context id. This struct is used to clean up the
// current context id after processMessage() is done.
struct DistAutogradContextGuard {
  explicit DistAutogradContextGuard(int64_t ctxId) {
    auto& container = DistAutogradContainer::getInstance();
    prevCtxId_ = container.currentContextId();
    container.forceCurrentContextId(ctxId);
  }
  ~DistAutogradContextGuard() {
    auto& container = DistAutogradContainer::getInstance();
    container.forceCurrentContextId(prevCtxId_);
  }

  int64_t prevCtxId_;
};

std::unique_ptr<RpcCommandBase> RequestCallbackNoPython::
    deserializePythonRpcCommand(
        std::unique_ptr<RpcCommandBase> rpc,
        const MessageType& messageType) const {
  TORCH_CHECK(
      messageType != MessageType::PYTHON_CALL &&
          messageType != MessageType::PYTHON_REMOTE_CALL,
      "Python calls are not supported!");
  return rpc;
}

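// Entry point for an incoming request: deserializes it, waits for any RRefs
// among the arguments to be confirmed, processes the RPC (with server-side
// profiling if enabled), and stamps the response with the request's message
// id.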
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processMessage(
    Message& request,
    std::vector<c10::Stream> streams) const {
  // We need two futures here because processing an RPC message can pause
  // twice:
  // 1) waiting for all RRefs in the arguments to become confirmed;
  // 2) waiting for processRpc to finish.
  auto& rrefContext = RRefContext::getInstance();
  try {
    rrefContext.recordThreadLocalPendingRRefs();
    // Deserialize PythonUDF here to trigger RRef unpickling
    std::unique_ptr<RpcCommandBase> rpc = deserializePythonRpcCommand(
        deserializeRequest(request), request.type());
    auto rrefsReadyFuture = rrefContext.waitForThreadLocalPendingRRefs();

    auto retFuture = rrefsReadyFuture->thenAsync(
        [this,
         // std::function must be copyable, hence we have to cast the
         // unique_ptr to a shared_ptr here.
         rpc = (std::shared_ptr<RpcCommandBase>)std::move(rpc),
         messageType = request.type(),
         streams = std::move(streams)](JitFuture& /* unused */) mutable {
          // The cost of the pre-request check is minimal thanks to
          // std::shared_lock; it is on the order of 10us.
          auto serverProcessGlobalProfilerStateStackEntryPtr =
              profiler::processglobal::StateStackEntry::current();
          // If the server process-global profiler is enabled, we further pay
          // the cost of thread-local profiler state initialization.
          if (serverProcessGlobalProfilerStateStackEntryPtr) {
            // Initialize thread-local profiler state from process-global
            // profiler state.
            enableProfilerLegacy(
                serverProcessGlobalProfilerStateStackEntryPtr->statePtr()
                    ->config());
          }

          auto retFuture = processRpcWithErrors(*rpc, messageType, streams);

          // The response message has been sent at this moment; this
          // post-response work doesn't affect RPC trip time.
          if (serverProcessGlobalProfilerStateStackEntryPtr) {
            // Restore thread-local profiler state.
            thread_event_lists event_lists = disableProfilerLegacy();
            // Put thread_local event_lists into the process-global profiler
            // state.
            profiler::processglobal::pushResultRecursive(
                serverProcessGlobalProfilerStateStackEntryPtr, event_lists);
          }

          return retFuture;
        },
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());

    auto retFutureWithMessageId = retFuture->then(
        [id = request.id()](JitFuture& future) {
          c10::intrusive_ptr<Message> message =
              future.value().toCustomClass<Message>();
          message->setId(id);
          return withStorages(message);
        },
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());

    return retFutureWithMessageId;
  } catch (std::exception& e) {
    rrefContext.clearRecordedPendingRRefsOnError();
    return asFuture(handleError(e, request.type(), request.id()));
  }
}

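// Runs processRpc() and converts any thrown exception into an error response
// future rather than letting it escape.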
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRpcWithErrors(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const std::vector<c10::Stream>& streams) const {
  try {
    return processRpc(rpc, messageType, streams);
  } catch (std::exception& e) {
    // Pass a dummy message ID since it will be overwritten anyway.
    return asFuture(handleError(e, messageType, -1));
  }
}

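// Handles SCRIPT_CALL: runs the builtin operator carried by the ScriptCall
// and returns its result in a ScriptResp message.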
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processScriptCall(
    RpcCommandBase& rpc,
    const std::vector<c10::Stream>& streams) const {
  auto& scriptCall = static_cast<ScriptCall&>(rpc);

  TORCH_CHECK(
      scriptCall.hasOp(), "Only supports the case where ScriptCall has an op");
  auto future =
      runJitOperator(*scriptCall.op(), scriptCall.stackRef(), streams);

  return future->then(
      [](JitFuture& future) {
        return withStorages(ScriptResp(future.value()).toMessage());
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processPythonCall(
    RpcCommandBase& rpc,
    const std::vector<c10::Stream>& /* unused */) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processPythonRemoteCall(
    RpcCommandBase& rpc,
    const std::vector<c10::Stream>& /* unused */) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}

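// Completes the owner RRef identified by rrefId with the eventual result in
// valueFuture, registering the fork first when the caller is a user, and
// replies with a RemoteRet message.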
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::assignOwnerRRef(
    const RRefId& rrefId,
    const RRefId& forkId,
    const c10::intrusive_ptr<JitFuture>& valueFuture) const {
  auto& ctx = RRefContext::getInstance();

  c10::intrusive_ptr<OwnerRRef> ownerRRef;
  if (rrefId == forkId) {
    // Creating an owner RRef on self; it should already exist in the
    // owners map.
    ownerRRef =
        fromRRefInterface(ctx.getOwnerRRef(rrefId, /* forceCreated */ true)
                              ->constValue()
                              .toRRef());
  } else {
    ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, valueFuture->elementType());
    // The caller is a user and the callee is the owner, so add the fork.
    //
    // NB: rrefId == forkId is true if and only if calling remote to self.
    // In that case both the caller and the callee will access the
    // OwnerRRef. Hence, on the callee side (here), it should not call
    // addForkOfOwner as it is not a fork. To allow the callee to distinguish
    // when this request is sent to self, the caller will set forkId using
    // rrefId (OwnerRRef does not have a forkId anyway).
    ctx.addForkOfOwner(rrefId, forkId);
  }

  return valueFuture->then(
      [ownerRRef, rrefId, forkId](JitFuture& future) {
        if (future.hasError()) {
          ownerRRef->setError(future.exception_ptr());
        } else {
          ownerRRef->setValue(future.value());
        }
        return withStorages(RemoteRet(rrefId, forkId).toMessage());
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}

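// Handles SCRIPT_REMOTE_CALL: runs the requested builtin operator and
// assigns its (possibly pending) result to the owner RRef named by the
// request.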
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processScriptRemoteCall(
    RpcCommandBase& rpc,
    const std::vector<c10::Stream>& streams) const {
  auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);

  TORCH_CHECK(
      scriptRemoteCall.hasOp(), "ScriptRemoteCall needs to have an op!");
  auto future = runJitOperator(
      *scriptRemoteCall.op(), scriptRemoteCall.stackRef(), streams);

  return assignOwnerRRef(
      scriptRemoteCall.retRRefId(), scriptRemoteCall.retForkId(), future);
}

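// Returns a future that completes with the value held by the given OwnerRRef,
// first waiting for the RRef itself to be created.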
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::retrieveOwnerRRef(
    const RRefId& rrefId) const {
  auto& ctx = RRefContext::getInstance();

  auto rrefFuture = ctx.getOwnerRRef(rrefId);

  at::TypePtr type = rrefFuture->elementType();
  TORCH_INTERNAL_ASSERT(type->kind() == at::RRefType::Kind);
  return rrefFuture->thenAsync(
      [](JitFuture& rrefFuture) {
        c10::intrusive_ptr<OwnerRRef> rref =
            fromRRefInterface(rrefFuture.value().toRRef());
        return rref->getFuture();
      },
      type->cast<at::RRefType>()->getElementType());
}

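// Handles SCRIPT_RREF_FETCH_CALL: fetches the owner RRef's value and returns
// it in a ScriptRRefFetchRet message.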
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processScriptRRefFetchCall(RpcCommandBase& rpc) const {
  auto& srf = static_cast<ScriptRRefFetchCall&>(rpc);

  auto future = retrieveOwnerRRef(srf.rrefId());

  return future->then(
      [](JitFuture& future) {
        return withStorages(ScriptRRefFetchRet({future.value()}).toMessage());
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processPythonRRefFetchCall(RpcCommandBase& rpc) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}

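// Handles RREF_USER_DELETE: removes the user's fork from the owner's
// bookkeeping and runs handleRRefDelete() on the RRef returned by
// delForkOfOwner().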
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefUserDelete(
    RpcCommandBase& rpc) const {
  auto& rud = static_cast<RRefUserDelete&>(rpc);
  auto& ctx = RRefContext::getInstance();
  auto deletedRRef = ctx.delForkOfOwner(rud.rrefId(), rud.forkId());
  handleRRefDelete(deletedRRef);
  return asFuture(RRefAck().toMessage());
}

void RequestCallbackNoPython::handleRRefDelete(
    c10::intrusive_ptr<RRef>& rref) const {
  TORCH_CHECK(!rref->isPyObj(), "RRefs with python objects not supported!");
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefChildAccept(
    RpcCommandBase& rpc) const {
  auto& rca = static_cast<RRefChildAccept&>(rpc);
  auto& ctx = RRefContext::getInstance();
  ctx.delPendingChild(rca.forkId());
  return asFuture(RRefAck().toMessage());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefForkRequest(
    RpcCommandBase& rpc) const {
  auto& rfr = static_cast<RRefForkRequest&>(rpc);
  auto& ctx = RRefContext::getInstance();
  ctx.addForkOfOwnerIfNotPresent(rfr.rrefId(), rfr.forkId());
  return asFuture(RRefAck().toMessage());
}

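// Handles FORWARD_AUTOGRAD_REQ: attaches a 'recv' autograd function for the
// sender, runs the wrapped RPC under the sender's distributed autograd
// context, and wraps the response in a FORWARD_AUTOGRAD_RESP message.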
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processForwardAutogradReq(
        RpcCommandBase& rpc,
        const std::vector<c10::Stream>& streams) const {
  auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);

  // Need to reverse the device map for the backward pass of distributed
  // autograd.
  DeviceMap reverseDeviceMap;
  for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
    reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
  }

  // Attach 'recv' autograd function.
  auto autogradContext = addRecvRpcBackward(
      rpcWithAutograd.autogradMetadata(),
      rpcWithAutograd.tensors(),
      rpcWithAutograd.fromWorkerId(),
      reverseDeviceMap);
  // For this recv thread on the server side, before processRpc(), set
  // current_context_id_ to the context_id passed from the client. This way,
  // if there is a nested rpc call within a python rpc call, the original
  // context_id from the client can be passed along the chain of calls.
  TORCH_INTERNAL_ASSERT(
      autogradContext != nullptr,
      "autogradContext is nullptr, FORWARD_AUTOGRAD_REQ should always get "
      "or create valid autogradContext in addRecvRpcBackward.");

  DistAutogradContextGuard ctxGuard(autogradContext->contextId());

  // Process the original RPC.
  auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
  // Kick off processing for the nested RPC command.
  // wrappedRpcResponseFuture will be a Future<T> to the result.
  auto wrappedRpcResponseFuture =
      processRpc(rpcWithAutograd.wrappedRpc(), wrappedMessageType, streams);

  auto fromWorkerId = rpcWithAutograd.fromWorkerId();
  // The original future needs to be marked as completed when the wrapped
  // one completes, with the autograd context information wrapped.
  auto responseFuture = wrappedRpcResponseFuture->then(
      [fromWorkerId, ctxId = autogradContext->contextId()](
          JitFuture& wrappedRpcResponseFuture) {
        // As this callback can be invoked by a different thread, we have to
        // make sure that the thread_local states in the previous thread are
        // correctly propagated.
        // NB: The execution of TorchScript functions can also run on a
        // different thread, which is addressed by
        // https://github.com/pytorch/pytorch/pull/36395
        // NB: when adding async UDF support, we should also propagate
        // thread_local states there.
        // TODO: Land on a general solution for RPC ThreadLocalState. See
        // https://github.com/pytorch/pytorch/issues/38510
        DistAutogradContextGuard cbCtxGuard(ctxId);

        if (wrappedRpcResponseFuture.hasError()) {
          // Propagate error to responseFuture if we had one.
          std::rethrow_exception(wrappedRpcResponseFuture.exception_ptr());
        } else {
          auto msg = getMessageWithAutograd(
              fromWorkerId,
              wrappedRpcResponseFuture.value().toCustomClass<Message>(),
              MessageType::FORWARD_AUTOGRAD_RESP);
          return withStorages(std::move(msg));
        }
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());

  return responseFuture;
}

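// Handles BACKWARD_AUTOGRAD_REQ: looks up the autograd context and the
// matching 'send' function, attaches the received gradients, and kicks off
// the distributed engine; the response completes once execution finishes.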
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processBackwardAutogradReq(
        RpcCommandBase& rpc,
        const std::vector<c10::Stream>& streams) const {
  c10::MultiStreamGuard guard(streams);
  auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
  const auto& autogradMetadata = gradientsCall.getAutogradMetadata();

  // Retrieve the appropriate autograd context.
  auto autogradContext = DistAutogradContainer::getInstance().retrieveContext(
      autogradMetadata.autogradContextId);

  // Lookup the appropriate 'send' function to enqueue.
  std::shared_ptr<SendRpcBackward> sendFunction =
      autogradContext->retrieveSendFunction(autogradMetadata.autogradMessageId);

  // Attach the gradients to the send function.
  sendFunction->setGrads(gradientsCall.getGrads());

  // Now execute the autograd graph using the "distributed engine."
  auto execFuture = DistEngine::getInstance().executeSendFunctionAsync(
      autogradContext, sendFunction, gradientsCall.retainGraph());

  // Our response is satisfied when the RPCs come back.
  return execFuture->then(
      [](JitFuture& execFuture) {
        if (execFuture.hasError()) {
          std::rethrow_exception(execFuture.exception_ptr());
        } else {
          return withStorages(PropagateGradientsResp().toMessage());
        }
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processCleanupAutogradContextReq(RpcCommandBase& rpc) const {
  auto& cleanupContextReq = static_cast<CleanupAutogradContextReq&>(rpc);
  auto cleanupContextId = cleanupContextReq.getContextId();
  // Release the context if it still exists on this thread. We need to
  // check if it exists since it may have been deleted by an in-flight
  // RPC. This can create nested RPCs if there are other nodes that get
  // notified to clean up their context.
  DistAutogradContainer::getInstance().releaseContextIfPresent(
      cleanupContextId);
  return asFuture(CleanupAutogradContextResp().toMessage());
}

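// Handles RUN_WITH_PROFILING_REQ: enables the legacy profiler with the
// (possibly downgraded) config from the sender, runs the wrapped RPC, and
// ships the collected events back in the response.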
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processRunWithProfilingReq(RpcCommandBase& rpc) const {
  auto& rpcWithProfilingReq = static_cast<RpcWithProfilingReq&>(rpc);
  auto wrappedMsgType = rpcWithProfilingReq.wrappedMessageType();
  auto profilingConfig = rpcWithProfilingReq.getProfilingConfig();

  if (profilingConfig.state == ProfilerState::KINETO ||
      profilingConfig.state == ProfilerState::KINETO_GPU_FALLBACK ||
      profilingConfig.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
    profilingConfig = ProfilerConfig(
        ProfilerState::CPU,
        profilingConfig.report_input_shapes,
        profilingConfig.profile_memory);
  }

  // If the caller requested CUDA profiling but CUDA is not available on this
  // machine, fall back to CPU profiling and log a warning instead of
  // crashing.
  if (profilingConfig.state == ProfilerState::CUDA && !this->cudaAvailable()) {
    profilingConfig = ProfilerConfig(
        ProfilerState::CPU,
        profilingConfig.report_input_shapes,
        profilingConfig.profile_memory);

    LOG(WARNING) << "Profiler was requested to be enabled with CUDA on this "
                    "node, but CUDA is not available. "
                 << "Falling back to CPU profiling only.";
  }
  TORCH_INTERNAL_ASSERT(
      profilingConfig.state != ProfilerState::CUDA || this->cudaAvailable(),
      "Profiler state set to CUDA but CUDA not available.");
  const auto profilingKeyId = rpcWithProfilingReq.getProfilingId();
  // Enable the profiler with the config from the sender.
  // When enabling on the main thread, ensure profiler states are cleaned
  // up, but defer consolidation of all profiled events to the continuation
  // below.
  ProfilerDisableOptions requestThreadOptions(
      true /* cleanup TLS state */, false /* consolidate events */);
  {
    TLSLegacyProfilerGuard g(
        profilingConfig, std::nullopt, requestThreadOptions);
    TORCH_INTERNAL_ASSERT(
        profilerEnabled(), "Expected profiler to be enabled!");
    // Kick off processing for nested work and get the Future<T> result in
    // wrappedRpcResponseFuture.
    auto wrappedRpcResponseFuture = processRpc(
        rpcWithProfilingReq.wrappedRpc(),
        wrappedMsgType,
        {}); // TODO: https://github.com/pytorch/pytorch/issues/55757

    auto responseFuture = wrappedRpcResponseFuture->then(
        at::wrapPropagateTLSState([profilingKeyId, profilingConfig](
                                      JitFuture& wrappedRpcResponseFuture) {
          std::vector<LegacyEvent> profiledEvents;
          // Defer consolidation of profiler events until async work has
          // completed (such as an async UDF).

          TORCH_INTERNAL_ASSERT(
              profilerEnabled(), "Expected profiler to be enabled!");

          // On the continuation thread, don't clean up profiler states, since
          // they will be cleaned up by the main thread, and consolidate all
          // events so we obtain asynchronously run events.
          ProfilerDisableOptions opts(false, true);
          auto event_lists = disableProfilerLegacy(opts);
          if (wrappedRpcResponseFuture.hasError()) {
            // Propagate the error.
            // No need to propagate remote events in the case of an error.
            std::rethrow_exception(wrappedRpcResponseFuture.exception_ptr());
          } else {
            populateRemoteProfiledEvents(
                profiledEvents, profilingConfig, event_lists);
            auto rpcWithProfilingResp = std::make_unique<RpcWithProfilingResp>(
                MessageType::RUN_WITH_PROFILING_RESP,
                wrappedRpcResponseFuture.value().toCustomClass<Message>(),
                profiledEvents,
                profilingKeyId);
            return withStorages(std::move(*rpcWithProfilingResp).toMessage());
          }
        }),
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());

    return responseFuture;
    // Exiting the scope will disable the profiler on this thread with the
    // options specified above.
  }
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefBackward(
    RpcCommandBase& rpc) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}

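// Dispatches a deserialized request to the handler for its message type.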
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRpc(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const std::vector<c10::Stream>& streams) const {
  // TODO: RpcCommandBase should have an abstract execute() method that we can
  // call here instead of having another switch statement. Even better, we
  // could have abstract classes RpcRequest and RpcResp which inherit from
  // RpcCommandBase, where RpcRequest declares the abstract method execute()
  // that we can call here. RpcResp could have an abstract method to convert
  // it to a python object.
  switch (messageType) {
    case MessageType::SCRIPT_CALL: {
      return processScriptCall(rpc, streams);
    }
    case MessageType::PYTHON_CALL: {
      return processPythonCall(rpc, streams);
    }
    case MessageType::SCRIPT_REMOTE_CALL: {
      return processScriptRemoteCall(rpc, streams);
    }
    case MessageType::PYTHON_REMOTE_CALL: {
      return processPythonRemoteCall(rpc, streams);
    }
    case MessageType::SCRIPT_RREF_FETCH_CALL: {
      return processScriptRRefFetchCall(rpc);
    }
    case MessageType::PYTHON_RREF_FETCH_CALL: {
      return processPythonRRefFetchCall(rpc);
    }
    case MessageType::RREF_USER_DELETE: {
      return processRRefUserDelete(rpc);
    }
    case MessageType::RREF_CHILD_ACCEPT: {
      return processRRefChildAccept(rpc);
    }
    case MessageType::RREF_FORK_REQUEST: {
      return processRRefForkRequest(rpc);
    }
    case MessageType::FORWARD_AUTOGRAD_REQ: {
      return processForwardAutogradReq(rpc, streams);
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      return processBackwardAutogradReq(rpc, streams);
    }
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
      return processCleanupAutogradContextReq(rpc);
    }
    case MessageType::RUN_WITH_PROFILING_REQ: {
      return processRunWithProfilingReq(rpc);
    }
    case MessageType::RREF_BACKWARD_REQ: {
      return processRRefBackward(rpc);
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Request type ", messageType, " not supported.");
    }
  }
}

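// Logs the exception and converts it into an exception response message
// annotated with this node's worker id.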
c10::intrusive_ptr<Message> RequestCallbackNoPython::handleError(
    const std::exception& e,
    const MessageType messageType,
    int64_t messageId) const {
  LOG(ERROR) << "Received error while processing request type " << messageType
             << ": " << e.what();
  // Adding node information to the error here since all processed RPC
  // requests should be going through this function.
  std::string errorMsg = c10::str(
      "Error on Node ",
      DistAutogradContainer::getInstance().getWorkerId(),
      ": ",
      e.what());
  return createExceptionResponse(errorMsg, messageId);
}

bool RequestCallbackNoPython::cudaAvailable() const {
#ifdef USE_CUDA
  return true;
#else
  return false;
#endif
}

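// Runs a builtin JIT operator on the given stack under the provided streams
// and returns its single result, or the caught exception, as a completed
// future.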
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::runJitOperator(
    const jit::Operator& op,
    std::vector<at::IValue>& stack,
    const std::vector<c10::Stream>& streams) const {
  c10::MultiStreamGuard guard(streams);
  try {
    op.getOperation()(stack);
  } catch (const std::exception&) {
    return asFuture(std::current_exception());
  }
  TORCH_INTERNAL_ASSERT(
      stack.size() == 1,
      "Return value of a builtin operator or a TorchScript function should be "
      "a single IValue, got a vector of size ",
      stack.size());
  TypePtr type = stack.front().type();
  return asFuture(std::move(stack.front()), std::move(type));
}

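// The asFuture() overloads below wrap an already-available value, message, or
// error in a completed JitFuture.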
c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture(
    IValue value,
    TypePtr type) const {
  auto future = c10::make_intrusive<JitFuture>(
      std::move(type), RpcAgent::getCurrentRpcAgent()->getDevices());
  future->markCompleted(std::move(value));
  return future;
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture(
    c10::intrusive_ptr<Message> message) const {
  auto future = c10::make_intrusive<JitFuture>(
      at::getCustomClassType<c10::intrusive_ptr<Message>>(),
      RpcAgent::getCurrentRpcAgent()->getDevices());
  std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages =
      message->getStorages();
  future->markCompleted(std::move(message), std::move(storages));
  return future;
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture(
    std::exception_ptr err) const {
  auto future = c10::make_intrusive<JitFuture>(
      at::NoneType::get(), RpcAgent::getCurrentRpcAgent()->getDevices());
  future->setError(std::move(err));
  return future;
}

} // namespace torch::distributed::rpc