#include <torch/csrc/distributed/rpc/request_callback_no_python.h>

#include <c10/core/StreamGuard.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/server_process_global_profiler.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_resp.h>
#include <torch/csrc/distributed/rpc/utils.h>

#include <utility>

namespace torch::distributed::rpc {

using namespace torch::distributed::autograd;
using namespace torch::autograd::profiler;

// When a request message has autograd info, processMessage() will set up a
// valid current context id. This struct restores the previous context id
// once processMessage() is done.
struct DistAutogradContextGuard {
  explicit DistAutogradContextGuard(int64_t ctxId) {
    auto& container = DistAutogradContainer::getInstance();
    prevCtxId_ = container.currentContextId();
    container.forceCurrentContextId(ctxId);
  }
  ~DistAutogradContextGuard() {
    auto& container = DistAutogradContainer::getInstance();
    container.forceCurrentContextId(prevCtxId_);
  }

  int64_t prevCtxId_;
};
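
// Illustrative usage sketch (ctxIdFromClient is a placeholder name): the
// guard is a plain RAII wrapper, so restoring the previous context id is
// exception-safe:
//
//   {
//     DistAutogradContextGuard guard(ctxIdFromClient);
//     // ... process the RPC with ctxIdFromClient as the current context ...
//   } // previous context id restored here, even if processing throws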

std::unique_ptr<RpcCommandBase> RequestCallbackNoPython::
    deserializePythonRpcCommand(
        std::unique_ptr<RpcCommandBase> rpc,
        const MessageType& messageType) const {
  TORCH_CHECK(
      messageType != MessageType::PYTHON_CALL &&
          messageType != MessageType::PYTHON_REMOTE_CALL,
      "Python calls are not supported!");
  return rpc;
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processMessage(
    Message& request,
    std::vector<c10::Stream> streams) const {
  // We need two futures here because processing an RPC message could pause
  // twice:
  //  1) waiting for all RRefs in the arguments to become confirmed;
  //  2) waiting for processRpc to finish.
  auto& rrefContext = RRefContext::getInstance();
  try {
    rrefContext.recordThreadLocalPendingRRefs();
    // Deserialize PythonUDF here to trigger RRef unpickling
    std::unique_ptr<RpcCommandBase> rpc = deserializePythonRpcCommand(
        deserializeRequest(request), request.type());
    auto rrefsReadyFuture = rrefContext.waitForThreadLocalPendingRRefs();

    auto retFuture = rrefsReadyFuture->thenAsync(
        [this,
         // std::function must be copyable, hence we have to cast the
         // unique_ptr to a shared_ptr here.
         rpc = (std::shared_ptr<RpcCommandBase>)std::move(rpc),
         messageType = request.type(),
         streams = std::move(streams)](JitFuture& /* unused */) mutable {
          // The cost of the pre-request check is minimal thanks to
          // std::shared_lock; it is on the order of 10us.
          auto serverProcessGlobalProfilerStateStackEntryPtr =
              profiler::processglobal::StateStackEntry::current();
          // If the server process-global profiler is enabled, we further pay
          // the cost of thread-local profiler state initialization.
          if (serverProcessGlobalProfilerStateStackEntryPtr) {
            // Initialize thread-local profiler state from process-global
            // profiler state.
            enableProfilerLegacy(
                serverProcessGlobalProfilerStateStackEntryPtr->statePtr()
                    ->config());
          }

          auto retFuture = processRpcWithErrors(*rpc, messageType, streams);

          // The response message has been sent at this moment, so this
          // post-response work doesn't affect the RPC trip time.
          if (serverProcessGlobalProfilerStateStackEntryPtr) {
            // Restore thread-local profiler state.
            thread_event_lists event_lists = disableProfilerLegacy();
            // Put the thread-local event_lists into the process-global
            // profiler state.
            profiler::processglobal::pushResultRecursive(
                serverProcessGlobalProfilerStateStackEntryPtr, event_lists);
          }

          return retFuture;
        },
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());

    auto retFutureWithMessageId = retFuture->then(
        [id = request.id()](JitFuture& future) {
          c10::intrusive_ptr<Message> message =
              future.value().toCustomClass<Message>();
          message->setId(id);
          return withStorages(message);
        },
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());

    return retFutureWithMessageId;
  } catch (std::exception& e) {
    rrefContext.clearRecordedPendingRRefsOnError();
    return asFuture(handleError(e, request.type(), request.id()));
  }
}
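
// Illustrative summary of the happy path above:
//   deserializeRequest -> wait for pending RRef confirmations
//     -> processRpcWithErrors -> stamp the response with the request id
//     -> withStorages collects the message's storages for the returned future.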

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRpcWithErrors(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const std::vector<c10::Stream>& streams) const {
  try {
    return processRpc(rpc, messageType, streams);
  } catch (std::exception& e) {
    // Pass a dummy message ID since it will be overwritten anyways.
    return asFuture(handleError(e, messageType, -1));
  }
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processScriptCall(
    RpcCommandBase& rpc,
    const std::vector<c10::Stream>& streams) const {
  auto& scriptCall = static_cast<ScriptCall&>(rpc);

  TORCH_CHECK(
      scriptCall.hasOp(), "Only supports the case where ScriptCall has an op");
  auto future =
      runJitOperator(*scriptCall.op(), scriptCall.stackRef(), streams);

  return future->then(
      [](JitFuture& future) {
        return withStorages(ScriptResp(future.value()).toMessage());
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}
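
// Example (illustrative): a SCRIPT_CALL carrying a builtin such as aten::add
// is executed inline via runJitOperator above, and the single IValue result
// is wrapped into a ScriptResp message in the continuation.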

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processPythonCall(
    RpcCommandBase& rpc,
    const std::vector<c10::Stream>& /* unused */) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processPythonRemoteCall(
    RpcCommandBase& rpc,
    const std::vector<c10::Stream>& /* unused */) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::assignOwnerRRef(
    const RRefId& rrefId,
    const RRefId& forkId,
    const c10::intrusive_ptr<JitFuture>& valueFuture) const {
  auto& ctx = RRefContext::getInstance();

  c10::intrusive_ptr<OwnerRRef> ownerRRef;
  if (rrefId == forkId) {
    // Creating an owner RRef on self; it should already exist in the owners
    // map.
    ownerRRef =
        fromRRefInterface(ctx.getOwnerRRef(rrefId, /* forceCreated */ true)
                              ->constValue()
                              .toRRef());
  } else {
    ownerRRef = ctx.getOrCreateOwnerRRef(rrefId, valueFuture->elementType());
    // Caller is a user and callee is the owner, add fork
    //
    // NB: rrefId == forkId is true if and only if calling remote to self.
    // In that case both the caller and the callee will access the
    // OwnerRRef. Hence, on the callee side (here), it should not call
    // addForkOfOwner as it is not a fork. To allow the callee to distinguish
    // when this request is sent to self, the caller will set forkId using
    // rrefId (OwnerRRef does not have a forkId anyway).
    ctx.addForkOfOwner(rrefId, forkId);
  }

  return valueFuture->then(
      [ownerRRef, rrefId, forkId](JitFuture& future) {
        if (future.hasError()) {
          ownerRRef->setError(future.exception_ptr());
        } else {
          ownerRRef->setValue(future.value());
        }
        return withStorages(RemoteRet(rrefId, forkId).toMessage());
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}
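
// Illustrative (assumed) timeline for the common user-to-owner remote call:
//   caller: creates rrefId and a distinct forkId, sends the remote request
//   owner:  getOrCreateOwnerRRef(rrefId); addForkOfOwner(rrefId, forkId)
//   owner:  setValue()/setError() once the wrapped computation finishes
//   owner:  replies with RemoteRet(rrefId, forkId)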

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processScriptRemoteCall(
    RpcCommandBase& rpc,
    const std::vector<c10::Stream>& streams) const {
  auto& scriptRemoteCall = static_cast<ScriptRemoteCall&>(rpc);

  TORCH_CHECK(
      scriptRemoteCall.hasOp(), "ScriptRemoteCall needs to have an op!");
  auto future = runJitOperator(
      *scriptRemoteCall.op(), scriptRemoteCall.stackRef(), streams);

  return assignOwnerRRef(
      scriptRemoteCall.retRRefId(), scriptRemoteCall.retForkId(), future);
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::retrieveOwnerRRef(
    const RRefId& rrefId) const {
  auto& ctx = RRefContext::getInstance();

  auto rrefFuture = ctx.getOwnerRRef(rrefId);

  at::TypePtr type = rrefFuture->elementType();
  TORCH_INTERNAL_ASSERT(type->kind() == at::RRefType::Kind);
  return rrefFuture->thenAsync(
      [](JitFuture& rrefFuture) {
        c10::intrusive_ptr<OwnerRRef> rref =
            fromRRefInterface(rrefFuture.value().toRRef());
        return rref->getFuture();
      },
      type->cast<at::RRefType>()->getElementType());
}
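
// In effect, retrieveOwnerRRef flattens a Future<RRef<T>> into a Future<T>:
// once the OwnerRRef itself is ready, the callback chains onto the RRef's
// value future, so callers only observe the element type T.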

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processScriptRRefFetchCall(RpcCommandBase& rpc) const {
  auto& srf = static_cast<ScriptRRefFetchCall&>(rpc);

  auto future = retrieveOwnerRRef(srf.rrefId());

  return future->then(
      [](JitFuture& future) {
        return withStorages(ScriptRRefFetchRet({future.value()}).toMessage());
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processPythonRRefFetchCall(RpcCommandBase& rpc) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefUserDelete(
    RpcCommandBase& rpc) const {
  auto& rud = static_cast<RRefUserDelete&>(rpc);
  auto& ctx = RRefContext::getInstance();
  auto deletedRRef = ctx.delForkOfOwner(rud.rrefId(), rud.forkId());
  handleRRefDelete(deletedRRef);
  return asFuture(RRefAck().toMessage());
}

void RequestCallbackNoPython::handleRRefDelete(
    c10::intrusive_ptr<RRef>& rref) const {
  TORCH_CHECK(!rref->isPyObj(), "RRefs with python objects not supported!");
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefChildAccept(
    RpcCommandBase& rpc) const {
  auto& rca = static_cast<RRefChildAccept&>(rpc);
  auto& ctx = RRefContext::getInstance();
  ctx.delPendingChild(rca.forkId());
  return asFuture(RRefAck().toMessage());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefForkRequest(
    RpcCommandBase& rpc) const {
  auto& rfr = static_cast<RRefForkRequest&>(rpc);
  auto& ctx = RRefContext::getInstance();
  ctx.addForkOfOwnerIfNotPresent(rfr.rrefId(), rfr.forkId());
  return asFuture(RRefAck().toMessage());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processForwardAutogradReq(
        RpcCommandBase& rpc,
        const std::vector<c10::Stream>& streams) const {
  auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);

  // Need to reverse the device map for the backward pass of distributed
  // autograd.
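  // For example (illustrative), a request-time map of {cuda:0 -> cuda:1}
  // becomes {cuda:1 -> cuda:0}, so gradients flow back to the caller's
  // original devices.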
  DeviceMap reverseDeviceMap;
  for (const auto& mapEntry : rpcWithAutograd.deviceMap()) {
    reverseDeviceMap.insert({mapEntry.second, mapEntry.first});
  }

  // Attach 'recv' autograd function.
  auto autogradContext = addRecvRpcBackward(
      rpcWithAutograd.autogradMetadata(),
      rpcWithAutograd.tensors(),
      rpcWithAutograd.fromWorkerId(),
      reverseDeviceMap);
  // For this recv thread on the server side, set current_context_id_ to the
  // context_id passed from the client before processRpc(). This way, if a
  // Python RPC call makes nested RPC calls, the original context_id from the
  // client is propagated through the chain of calls.
  TORCH_INTERNAL_ASSERT(
      autogradContext != nullptr,
      "autogradContext is nullptr, FORWARD_AUTOGRAD_REQ should always get "
      "or create valid autogradContext in addRecvRpcBackward.");

  DistAutogradContextGuard ctxGuard(autogradContext->contextId());

  // Process the original RPC.
  auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
  // Kick off processing for the nested RPC command.
  // wrappedRpcResponseFuture will be a Future<T> to the result.
  auto wrappedRpcResponseFuture =
      processRpc(rpcWithAutograd.wrappedRpc(), wrappedMessageType, streams);

  auto fromWorkerId = rpcWithAutograd.fromWorkerId();
  // The original future needs to be marked as completed when the wrapped
  // one completes, with the autograd context information wrapped.
  auto responseFuture = wrappedRpcResponseFuture->then(
      [fromWorkerId, ctxId = autogradContext->contextId()](
          JitFuture& wrappedRpcResponseFuture) {
        // As this callback can be invoked by a different thread, we have to
        // make sure that the thread_local states of the previous thread are
        // correctly propagated.
        // NB: The execution of TorchScript functions can also run on a
        // different thread, which is addressed by
        // https://github.com/pytorch/pytorch/pull/36395
        // NB: when adding async UDF support, we should also propagate
        // thread_local states there.
        // TODO: Land on a general solution for RPC ThreadLocalState. See
        // https://github.com/pytorch/pytorch/issues/38510
        DistAutogradContextGuard cbCtxGuard(ctxId);

        if (wrappedRpcResponseFuture.hasError()) {
          // Propagate error to responseFuture if we had one.
          std::rethrow_exception(wrappedRpcResponseFuture.exception_ptr());
        } else {
          auto msg = getMessageWithAutograd(
              fromWorkerId,
              wrappedRpcResponseFuture.value().toCustomClass<Message>(),
              MessageType::FORWARD_AUTOGRAD_RESP);
          return withStorages(std::move(msg));
        }
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());

  return responseFuture;
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processBackwardAutogradReq(
        RpcCommandBase& rpc,
        const std::vector<c10::Stream>& streams) const {
  c10::MultiStreamGuard guard(streams);
  auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
  const auto& autogradMetadata = gradientsCall.getAutogradMetadata();

  // Retrieve the appropriate autograd context.
  auto autogradContext = DistAutogradContainer::getInstance().retrieveContext(
      autogradMetadata.autogradContextId);

  // Lookup the appropriate 'send' function to enqueue.
  std::shared_ptr<SendRpcBackward> sendFunction =
      autogradContext->retrieveSendFunction(autogradMetadata.autogradMessageId);

  // Attach the gradients to the send function.
  sendFunction->setGrads(gradientsCall.getGrads());

  // Now execute the autograd graph using the "distributed engine."
  auto execFuture = DistEngine::getInstance().executeSendFunctionAsync(
      autogradContext, sendFunction, gradientsCall.retainGraph());

  // Our response is satisfied when the rpcs come back.
  return execFuture->then(
      [](JitFuture& execFuture) {
        if (execFuture.hasError()) {
          std::rethrow_exception(execFuture.exception_ptr());
        } else {
          return withStorages(PropagateGradientsResp().toMessage());
        }
      },
      c10::getCustomClassType<c10::intrusive_ptr<Message>>());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processCleanupAutogradContextReq(RpcCommandBase& rpc) const {
  auto& cleanupContextReq = static_cast<CleanupAutogradContextReq&>(rpc);
  auto cleanupContextId = cleanupContextReq.getContextId();
  // release the context if it still exists on this thread. We need to
  // check if it exists since it may have been deleted by an in-flight
  // RPC. This can create nested RPCs if there are other nodes that get
  // notified to clean up their context.
  DistAutogradContainer::getInstance().releaseContextIfPresent(
      cleanupContextId);
  return asFuture(CleanupAutogradContextResp().toMessage());
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::
    processRunWithProfilingReq(RpcCommandBase& rpc) const {
  auto& rpcWithProfilingReq = static_cast<RpcWithProfilingReq&>(rpc);
  auto wrappedMsgType = rpcWithProfilingReq.wrappedMessageType();
  auto profilingConfig = rpcWithProfilingReq.getProfilingConfig();

  if (profilingConfig.state == ProfilerState::KINETO ||
      profilingConfig.state == ProfilerState::KINETO_GPU_FALLBACK ||
      profilingConfig.state == ProfilerState::KINETO_PRIVATEUSE1_FALLBACK) {
    profilingConfig = ProfilerConfig(
        ProfilerState::CPU,
        profilingConfig.report_input_shapes,
        profilingConfig.profile_memory);
  }

  // If the caller requested CUDA profiling but CUDA is not available on this
  // machine, fall back to CPU and log a warning instead of crashing.
  if (profilingConfig.state == ProfilerState::CUDA && !this->cudaAvailable()) {
    profilingConfig = ProfilerConfig(
        ProfilerState::CPU,
        profilingConfig.report_input_shapes,
        profilingConfig.profile_memory);

    LOG(WARNING) << "Profiler was requested to be enabled with CUDA on this "
                    "node, but CUDA is not available. "
                 << "Falling back to CPU profiling only.";
  }
  TORCH_INTERNAL_ASSERT(
      profilingConfig.state != ProfilerState::CUDA || this->cudaAvailable(),
      "Profiler state set to CUDA but CUDA not available.");
  const auto profilingKeyId = rpcWithProfilingReq.getProfilingId();
  // Enable the profiler with the config from the sender.
  // When enabling on the main thread, ensure profiler states are cleaned
  // up, but defer consolidation of all profiled events to the continuation
  // below.
  ProfilerDisableOptions requestThreadOptions(
      true /* cleanup TLS state */, false /* consolidate events */);
  {
    TLSLegacyProfilerGuard g(
        profilingConfig, std::nullopt, requestThreadOptions);
    TORCH_INTERNAL_ASSERT(
        profilerEnabled(), "Expected profiler to be enabled!");
    // Kick off processing for nested work and get Future<T> result in
    // wrappedRpcResponseFuture
    auto wrappedRpcResponseFuture = processRpc(
        rpcWithProfilingReq.wrappedRpc(),
        wrappedMsgType,
        {}); // TODO: https://github.com/pytorch/pytorch/issues/55757

    auto responseFuture = wrappedRpcResponseFuture->then(
        at::wrapPropagateTLSState([profilingKeyId, profilingConfig](
                                      JitFuture& wrappedRpcResponseFuture) {
          std::vector<LegacyEvent> profiledEvents;
          // Defer consolidation of profiler events until async work has
          // completed (such as async UDF)

          TORCH_INTERNAL_ASSERT(
              profilerEnabled(), "Expected profiler to be enabled!");

          // On the continuation thread, don't clean up profiler states, since
          // they will be cleaned up by the main thread; consolidate all
          // events so we obtain asynchronously run events.
          ProfilerDisableOptions opts(false, true);
          auto event_lists = disableProfilerLegacy(opts);
          if (wrappedRpcResponseFuture.hasError()) {
            // Propagate error
            // No need to propagate remote events in the case of an error.
            std::rethrow_exception(wrappedRpcResponseFuture.exception_ptr());
          } else {
            populateRemoteProfiledEvents(
                profiledEvents, profilingConfig, event_lists);
            auto rpcWithProfilingResp = std::make_unique<RpcWithProfilingResp>(
                MessageType::RUN_WITH_PROFILING_RESP,
                wrappedRpcResponseFuture.value().toCustomClass<Message>(),
                profiledEvents,
                profilingKeyId);
            return withStorages(std::move(*rpcWithProfilingResp).toMessage());
          }
        }),
        c10::getCustomClassType<c10::intrusive_ptr<Message>>());

    return responseFuture;
    // Exiting the scope will disable the profiler on this thread with the
    // options specified above.
  }
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRRefBackward(
    RpcCommandBase& rpc) const {
  C10_THROW_ERROR(Error, "Python call not supported!");
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::processRpc(
    RpcCommandBase& rpc,
    const MessageType& messageType,
    const std::vector<c10::Stream>& streams) const {
  // TODO: RpcCommandBase should have an abstract execute() method that we can
  // call here instead of having another switch statement. Even better, we
  // could have abstract classes RpcRequest and RpcResp which inherit from
  // RpcCommandBase; RpcRequest would declare the abstract method execute()
  // that we call here, and RpcResponse could have an abstract method to
  // convert it to a python object.
  switch (messageType) {
    case MessageType::SCRIPT_CALL: {
      return processScriptCall(rpc, streams);
    }
    case MessageType::PYTHON_CALL: {
      return processPythonCall(rpc, streams);
    }
    case MessageType::SCRIPT_REMOTE_CALL: {
      return processScriptRemoteCall(rpc, streams);
    }
    case MessageType::PYTHON_REMOTE_CALL: {
      return processPythonRemoteCall(rpc, streams);
    }
    case MessageType::SCRIPT_RREF_FETCH_CALL: {
      return processScriptRRefFetchCall(rpc);
    }
    case MessageType::PYTHON_RREF_FETCH_CALL: {
      return processPythonRRefFetchCall(rpc);
    }
    case MessageType::RREF_USER_DELETE: {
      return processRRefUserDelete(rpc);
    }
    case MessageType::RREF_CHILD_ACCEPT: {
      return processRRefChildAccept(rpc);
    }
    case MessageType::RREF_FORK_REQUEST: {
      return processRRefForkRequest(rpc);
    }
    case MessageType::FORWARD_AUTOGRAD_REQ: {
      return processForwardAutogradReq(rpc, streams);
    }
    case MessageType::BACKWARD_AUTOGRAD_REQ: {
      return processBackwardAutogradReq(rpc, streams);
    }
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ: {
      return processCleanupAutogradContextReq(rpc);
    }
    case MessageType::RUN_WITH_PROFILING_REQ: {
      return processRunWithProfilingReq(rpc);
    }
    case MessageType::RREF_BACKWARD_REQ: {
      return processRRefBackward(rpc);
    }
    default: {
      TORCH_INTERNAL_ASSERT(
          false, "Request type ", messageType, " not supported.");
    }
  }
}

c10::intrusive_ptr<Message> RequestCallbackNoPython::handleError(
    const std::exception& e,
    const MessageType messageType,
    int64_t messageId) const {
  LOG(ERROR) << "Received error while processing request type " << messageType
             << ": " << e.what();
  // Adding node information to the error here since all processed RPC
  // requests should be going through this function.
  std::string errorMsg = c10::str(
      "Error on Node ",
      DistAutogradContainer::getInstance().getWorkerId(),
      ": ",
      e.what());
  return createExceptionResponse(errorMsg, messageId);
}

bool RequestCallbackNoPython::cudaAvailable() const {
#ifdef USE_CUDA
  return true;
#else
  return false;
#endif
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::runJitOperator(
    const jit::Operator& op,
    std::vector<at::IValue>& stack,
    const std::vector<c10::Stream>& streams) const {
  c10::MultiStreamGuard guard(streams);
  try {
    op.getOperation()(stack);
  } catch (const std::exception&) {
    return asFuture(std::current_exception());
  }
  TORCH_INTERNAL_ASSERT(
      stack.size() == 1,
      "Return value of a builtin operator or a TorchScript function should be "
      "a single IValue, got a vector of size ",
      stack.size());
  TypePtr type = stack.front().type();
  return asFuture(std::move(stack.front()), std::move(type));
}
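
// Note (illustrative): for a builtin such as aten::add, the stack holds the
// operator's inputs on entry; getOperation() pops them and pushes the single
// result, which asFuture then wraps in an already-completed JitFuture.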

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture(
    IValue value,
    TypePtr type) const {
  auto future = c10::make_intrusive<JitFuture>(
      std::move(type), RpcAgent::getCurrentRpcAgent()->getDevices());
  future->markCompleted(std::move(value));
  return future;
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture(
    c10::intrusive_ptr<Message> message) const {
  auto future = c10::make_intrusive<JitFuture>(
      at::getCustomClassType<c10::intrusive_ptr<Message>>(),
      RpcAgent::getCurrentRpcAgent()->getDevices());
  std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> storages =
      message->getStorages();
  future->markCompleted(std::move(message), std::move(storages));
  return future;
}

c10::intrusive_ptr<JitFuture> RequestCallbackNoPython::asFuture(
    std::exception_ptr err) const {
  auto future = c10::make_intrusive<JitFuture>(
      at::NoneType::get(), RpcAgent::getCurrentRpcAgent()->getDevices());
  future->setError(std::move(err));
  return future;
}

} // namespace torch::distributed::rpc