xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/request_callback_impl.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/rpc/message.h>
4 #include <torch/csrc/distributed/rpc/request_callback_no_python.h>
5 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
6 #include <torch/csrc/jit/python/pybind.h>
7 
8 namespace torch::distributed::rpc {
9 
// RequestCallbackImpl derives publicly from RequestCallbackNoPython and
// overrides a set of its hooks (every member marked `override` below). It
// also adds two non-virtual helpers, runJitFunction and runPythonFunction.
// All member functions declared here are const. Only declarations are
// visible in this header; the definitions live elsewhere.
// NOTE(review): this looks like the Python-aware counterpart of
// RequestCallbackNoPython (it includes pybind.h and takes a py::object in
// runPythonFunction) -- confirm against the implementation file.
10 class TORCH_API RequestCallbackImpl : public RequestCallbackNoPython {
11  public:
     // Takes ownership of `rpc` (unique_ptr passed by value) together with the
     // message type, and returns a unique_ptr<RpcCommandBase>.
     // NOTE(review): presumably returns a deserialized form of Python-carrying
     // commands -- what happens for other message types is not visible here.
12   std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand(
13       std::unique_ptr<RpcCommandBase> rpc,
14       const MessageType& messageType) const override;
15 
     // The process* overrides below each receive the command by non-const
     // reference and return a JitFuture. Those that accept `streams` receive a
     // read-only list of c10::Stream; how the streams are used is not visible
     // in this header.
     // NOTE(review): going by the name, this one handles a Python call
     // command -- confirm in the definition.
16   c10::intrusive_ptr<JitFuture> processPythonCall(
17       RpcCommandBase& rpc,
18       const std::vector<c10::Stream>& streams) const override;
19 
     // NOTE(review): by name, handles a script (TorchScript) call command --
     // confirm in the definition.
20   c10::intrusive_ptr<JitFuture> processScriptCall(
21       RpcCommandBase& rpc,
22       const std::vector<c10::Stream>& streams) const override;
23 
     // NOTE(review): by name, the "remote" flavour of the script call --
     // how it differs from processScriptCall is not visible here.
24   c10::intrusive_ptr<JitFuture> processScriptRemoteCall(
25       RpcCommandBase& rpc,
26       const std::vector<c10::Stream>& streams) const override;
27 
     // NOTE(review): by name, the "remote" flavour of the Python call --
     // how it differs from processPythonCall is not visible here.
28   c10::intrusive_ptr<JitFuture> processPythonRemoteCall(
29       RpcCommandBase& rpc,
30       const std::vector<c10::Stream>& streams) const override;
31 
     // Unlike the overrides above, this one takes no `streams` argument.
     // NOTE(review): by name, serves a fetch of a Python RRef -- confirm.
32   c10::intrusive_ptr<JitFuture> processPythonRRefFetchCall(
33       RpcCommandBase& rpc) const override;
34 
     // Receives the RRef handle by non-const reference and returns nothing.
     // NOTE(review): whether the handle is reset or otherwise modified is not
     // visible here -- check the definition before relying on `rref` afterwards.
35   void handleRRefDelete(c10::intrusive_ptr<RRef>& rref) const override;
36 
     // Same inputs as the process* overrides plus the message type; returns a
     // JitFuture.
     // NOTE(review): the name suggests failures are surfaced through the
     // returned future rather than thrown -- confirm in the definition.
37   c10::intrusive_ptr<JitFuture> processRpcWithErrors(
38       RpcCommandBase& rpc,
39       const MessageType& messageType,
40       const std::vector<c10::Stream>& streams) const override;
41 
     // Returns a bool.
     // NOTE(review): presumably reports whether CUDA can be used in this
     // process -- the actual check is not visible here.
42   bool cudaAvailable() const override;
43 
     // Takes the command by reference (no `streams`) and returns a JitFuture.
     // NOTE(review): by name, handles a backward-pass request on an RRef --
     // confirm.
44   c10::intrusive_ptr<JitFuture> processRRefBackward(
45       RpcCommandBase& rpc) const override;
46 
47   // Helpers to run user-defined functions, operators and other computations.
     // Neither helper below is marked `override`, so they are declared here as
     // public, non-virtual members rather than as overrides of base hooks.
48 
     // Takes a qualified function name, a mutable reference to a vector of
     // IValues (`stack`), the stream list, and an async-execution flag; returns
     // a JitFuture.
     // NOTE(review): `stack` is non-const, so the callee may modify it --
     // presumably it carries the arguments; confirm in the definition.
49   c10::intrusive_ptr<JitFuture> runJitFunction(
50       const c10::QualifiedName& name,
51       std::vector<at::IValue>& stack,
52       const std::vector<c10::Stream>& streams,
53       bool isAsyncExecution) const;
54 
     // Takes a Python object (`function`), the stream list, and an
     // async-execution flag; returns a JitFuture.
     // NOTE(review): GIL requirements for callers are not stated in this
     // header -- check the definition before calling from a non-Python thread.
55   c10::intrusive_ptr<JitFuture> runPythonFunction(
56       const py::object& function,
57       const std::vector<c10::Stream>& streams,
58       bool isAsyncExecution) const;
59 };
60 
61 } // namespace torch::distributed::rpc
62