1 #pragma once 2 3 #include <torch/csrc/distributed/rpc/message.h> 4 #include <torch/csrc/distributed/rpc/request_callback.h> 5 #include <torch/csrc/distributed/rpc/rpc_command_base.h> 6 #include <torch/csrc/distributed/rpc/rref_impl.h> 7 #include <torch/csrc/distributed/rpc/script_call.h> 8 #include <torch/csrc/distributed/rpc/script_remote_call.h> 9 10 namespace torch::distributed::rpc { 11 12 // RequestCallback implementation with no Python dependencies. 13 class TORCH_API RequestCallbackNoPython : public RequestCallback { 14 public: 15 c10::intrusive_ptr<JitFuture> processMessage( 16 Message& request, 17 std::vector<c10::Stream> streams) const override; 18 19 protected: 20 virtual std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand( 21 std::unique_ptr<RpcCommandBase> rpc, 22 const MessageType& messageType) const; 23 24 virtual c10::intrusive_ptr<JitFuture> processScriptCall( 25 RpcCommandBase& rpc, 26 const std::vector<c10::Stream>& streams) const; 27 28 virtual c10::intrusive_ptr<JitFuture> processPythonCall( 29 RpcCommandBase& rpc, 30 const std::vector<c10::Stream>& streams) const; 31 32 c10::intrusive_ptr<JitFuture> assignOwnerRRef( 33 const RRefId& rrefId, 34 const RRefId& forkId, 35 const c10::intrusive_ptr<JitFuture>& valueFuture) const; 36 37 virtual c10::intrusive_ptr<JitFuture> processScriptRemoteCall( 38 RpcCommandBase& rpc, 39 const std::vector<c10::Stream>& streams) const; 40 41 virtual c10::intrusive_ptr<JitFuture> processPythonRemoteCall( 42 RpcCommandBase& rpc, 43 const std::vector<c10::Stream>& streams) const; 44 45 c10::intrusive_ptr<JitFuture> retrieveOwnerRRef(const RRefId& rrefId) const; 46 47 c10::intrusive_ptr<JitFuture> processScriptRRefFetchCall( 48 RpcCommandBase& rpc) const; 49 50 virtual c10::intrusive_ptr<JitFuture> processPythonRRefFetchCall( 51 RpcCommandBase& rpc) const; 52 53 c10::intrusive_ptr<JitFuture> processRRefUserDelete( 54 RpcCommandBase& rpc) const; 55 56 c10::intrusive_ptr<JitFuture> processRRefChildAccept( 57 RpcCommandBase& rpc) const; 58 59 c10::intrusive_ptr<JitFuture> processRRefForkRequest( 60 RpcCommandBase& rpc) const; 61 62 c10::intrusive_ptr<JitFuture> processForwardAutogradReq( 63 RpcCommandBase& rpc, 64 const std::vector<c10::Stream>& streams) const; 65 66 c10::intrusive_ptr<JitFuture> processBackwardAutogradReq( 67 RpcCommandBase& rpc, 68 const std::vector<c10::Stream>& streams) const; 69 70 c10::intrusive_ptr<JitFuture> processCleanupAutogradContextReq( 71 RpcCommandBase& rpc) const; 72 73 c10::intrusive_ptr<JitFuture> processRunWithProfilingReq( 74 RpcCommandBase& rpc) const; 75 76 virtual void handleRRefDelete(c10::intrusive_ptr<RRef>& rref) const; 77 78 c10::intrusive_ptr<JitFuture> processRpc( 79 RpcCommandBase& rpc, 80 const MessageType& messageType, 81 const std::vector<c10::Stream>& streams) const; 82 83 virtual c10::intrusive_ptr<JitFuture> processRpcWithErrors( 84 RpcCommandBase& rpc, 85 const MessageType& messageType, 86 const std::vector<c10::Stream>& streams) const; 87 88 c10::intrusive_ptr<Message> handleError( 89 const std::exception& e, 90 const MessageType messageType, 91 int64_t messageId) const; 92 93 virtual bool cudaAvailable() const; 94 95 virtual c10::intrusive_ptr<JitFuture> processRRefBackward( 96 RpcCommandBase& rpc) const; 97 98 // Helpers to run user-defined functions, operators and other computations. 99 100 c10::intrusive_ptr<JitFuture> runJitOperator( 101 const jit::Operator& op, 102 std::vector<at::IValue>& stack, 103 const std::vector<c10::Stream>& streams) const; 104 105 // Helpers to convert various kinds of objects into already-completed futures. 106 107 c10::intrusive_ptr<JitFuture> asFuture(IValue value, TypePtr type) const; 108 109 c10::intrusive_ptr<JitFuture> asFuture( 110 c10::intrusive_ptr<Message> message) const; 111 112 c10::intrusive_ptr<JitFuture> asFuture(std::exception_ptr err) const; 113 }; 114 115 } // namespace torch::distributed::rpc 116