1 #pragma once 2 3 #include <torch/csrc/distributed/rpc/message.h> 4 #include <torch/csrc/distributed/rpc/request_callback_no_python.h> 5 #include <torch/csrc/distributed/rpc/rpc_command_base.h> 6 #include <torch/csrc/jit/python/pybind.h> 7 8 namespace torch::distributed::rpc { 9 10 class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython { 11 public: 12 std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand( 13 std::unique_ptr<RpcCommandBase> rpc, 14 const MessageType& messageType) const override; 15 16 c10::intrusive_ptr<JitFuture> processPythonCall( 17 RpcCommandBase& rpc, 18 const std::vector<c10::Stream>& streams) const override; 19 20 c10::intrusive_ptr<JitFuture> processScriptCall( 21 RpcCommandBase& rpc, 22 const std::vector<c10::Stream>& streams) const override; 23 24 c10::intrusive_ptr<JitFuture> processScriptRemoteCall( 25 RpcCommandBase& rpc, 26 const std::vector<c10::Stream>& streams) const override; 27 28 c10::intrusive_ptr<JitFuture> processPythonRemoteCall( 29 RpcCommandBase& rpc, 30 const std::vector<c10::Stream>& streams) const override; 31 32 c10::intrusive_ptr<JitFuture> processPythonRRefFetchCall( 33 RpcCommandBase& rpc) const override; 34 35 void handleRRefDelete(c10::intrusive_ptr<RRef>& rref) const override; 36 37 c10::intrusive_ptr<JitFuture> processRpcWithErrors( 38 RpcCommandBase& rpc, 39 const MessageType& messageType, 40 const std::vector<c10::Stream>& streams) const override; 41 42 bool cudaAvailable() const override; 43 44 c10::intrusive_ptr<JitFuture> processRRefBackward( 45 RpcCommandBase& rpc) const override; 46 47 // Helpers to run user-defined functions, operators and other computations. 48 49 c10::intrusive_ptr<JitFuture> runJitFunction( 50 const c10::QualifiedName& name, 51 std::vector<at::IValue>& stack, 52 const std::vector<c10::Stream>& streams, 53 bool isAsyncExecution) const; 54 55 c10::intrusive_ptr<JitFuture> runPythonFunction( 56 const py::object& function, 57 const std::vector<c10::Stream>& streams, 58 bool isAsyncExecution) const; 59 }; 60 61 } // namespace torch::distributed::rpc 62