xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/request_callback_no_python.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/rpc/message.h>
4 #include <torch/csrc/distributed/rpc/request_callback.h>
5 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
6 #include <torch/csrc/distributed/rpc/rref_impl.h>
7 #include <torch/csrc/distributed/rpc/script_call.h>
8 #include <torch/csrc/distributed/rpc/script_remote_call.h>
9 
10 namespace torch::distributed::rpc {
11 
12 // RequestCallback implementation with no Python dependencies.
13 class TORCH_API RequestCallbackNoPython : public RequestCallback {
14  public:
15   c10::intrusive_ptr<JitFuture> processMessage(
16       Message& request,
17       std::vector<c10::Stream> streams) const override;
18 
19  protected:
20   virtual std::unique_ptr<RpcCommandBase> deserializePythonRpcCommand(
21       std::unique_ptr<RpcCommandBase> rpc,
22       const MessageType& messageType) const;
23 
24   virtual c10::intrusive_ptr<JitFuture> processScriptCall(
25       RpcCommandBase& rpc,
26       const std::vector<c10::Stream>& streams) const;
27 
28   virtual c10::intrusive_ptr<JitFuture> processPythonCall(
29       RpcCommandBase& rpc,
30       const std::vector<c10::Stream>& streams) const;
31 
32   c10::intrusive_ptr<JitFuture> assignOwnerRRef(
33       const RRefId& rrefId,
34       const RRefId& forkId,
35       const c10::intrusive_ptr<JitFuture>& valueFuture) const;
36 
37   virtual c10::intrusive_ptr<JitFuture> processScriptRemoteCall(
38       RpcCommandBase& rpc,
39       const std::vector<c10::Stream>& streams) const;
40 
41   virtual c10::intrusive_ptr<JitFuture> processPythonRemoteCall(
42       RpcCommandBase& rpc,
43       const std::vector<c10::Stream>& streams) const;
44 
45   c10::intrusive_ptr<JitFuture> retrieveOwnerRRef(const RRefId& rrefId) const;
46 
47   c10::intrusive_ptr<JitFuture> processScriptRRefFetchCall(
48       RpcCommandBase& rpc) const;
49 
50   virtual c10::intrusive_ptr<JitFuture> processPythonRRefFetchCall(
51       RpcCommandBase& rpc) const;
52 
53   c10::intrusive_ptr<JitFuture> processRRefUserDelete(
54       RpcCommandBase& rpc) const;
55 
56   c10::intrusive_ptr<JitFuture> processRRefChildAccept(
57       RpcCommandBase& rpc) const;
58 
59   c10::intrusive_ptr<JitFuture> processRRefForkRequest(
60       RpcCommandBase& rpc) const;
61 
62   c10::intrusive_ptr<JitFuture> processForwardAutogradReq(
63       RpcCommandBase& rpc,
64       const std::vector<c10::Stream>& streams) const;
65 
66   c10::intrusive_ptr<JitFuture> processBackwardAutogradReq(
67       RpcCommandBase& rpc,
68       const std::vector<c10::Stream>& streams) const;
69 
70   c10::intrusive_ptr<JitFuture> processCleanupAutogradContextReq(
71       RpcCommandBase& rpc) const;
72 
73   c10::intrusive_ptr<JitFuture> processRunWithProfilingReq(
74       RpcCommandBase& rpc) const;
75 
76   virtual void handleRRefDelete(c10::intrusive_ptr<RRef>& rref) const;
77 
78   c10::intrusive_ptr<JitFuture> processRpc(
79       RpcCommandBase& rpc,
80       const MessageType& messageType,
81       const std::vector<c10::Stream>& streams) const;
82 
83   virtual c10::intrusive_ptr<JitFuture> processRpcWithErrors(
84       RpcCommandBase& rpc,
85       const MessageType& messageType,
86       const std::vector<c10::Stream>& streams) const;
87 
88   c10::intrusive_ptr<Message> handleError(
89       const std::exception& e,
90       const MessageType messageType,
91       int64_t messageId) const;
92 
93   virtual bool cudaAvailable() const;
94 
95   virtual c10::intrusive_ptr<JitFuture> processRRefBackward(
96       RpcCommandBase& rpc) const;
97 
98   // Helpers to run user-defined functions, operators and other computations.
99 
100   c10::intrusive_ptr<JitFuture> runJitOperator(
101       const jit::Operator& op,
102       std::vector<at::IValue>& stack,
103       const std::vector<c10::Stream>& streams) const;
104 
105   // Helpers to convert various kinds of objects into already-completed futures.
106 
107   c10::intrusive_ptr<JitFuture> asFuture(IValue value, TypePtr type) const;
108 
109   c10::intrusive_ptr<JitFuture> asFuture(
110       c10::intrusive_ptr<Message> message) const;
111 
112   c10::intrusive_ptr<JitFuture> asFuture(std::exception_ptr err) const;
113 };
114 
115 } // namespace torch::distributed::rpc
116