xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/py_rref.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/rpc/rref_impl.h>
4 #include <torch/csrc/python_headers.h>
5 #include <torch/csrc/utils/pybind.h>
6 
7 namespace torch {
8 namespace distributed {
9 namespace rpc {
10 
11 enum RRefProxyType { RPC_SYNC, RPC_ASYNC, REMOTE };
12 
13 // Python wrapper of an RRef shared_ptr that supports Python
14 // pickle and unpickle.
15 class PYBIND11_EXPORT PyRRef {
16  public:
17   // The first ctor can only be called while holding GIL. See its implementation
18   // for more explanations.
19   explicit PyRRef(const py::object& value, const py::object& type_hint);
20   explicit PyRRef(c10::intrusive_ptr<RRef> rref);
21   PyRRef(const PyRRef&) = default;
22   ~PyRRef();
23 
24   bool isOwner() const;
25   bool confirmedByOwner() const;
26   WorkerInfo owner() const;
27   std::string ownerName() const;
28   py::object toHere(
29       const float timeoutSeconds =
30           torch::distributed::rpc::kUnsetRpcTimeout) const;
31   py::object localValue() const;
32   std::string str() const;
33   py::tuple pickle() const;
34   static PyRRef unpickle(const py::tuple& t);
35   c10::IValue toIValue() const;
36   // Future that is associated with the creation of this RRef on the remote end.
37   // This is only used to get the future corresponding to the rref for profiling
38   // use cases.
39   c10::intrusive_ptr<JitFuture> getFuture() const;
40   // Keeps track of the future responsible for profiling owner creation
41   // acknowledgement
42   c10::intrusive_ptr<JitFuture> getProfilingFuture() const;
43   // Sets the future responsible for profiling owner creation acknowledgement.
44   // This future is set from python to be a future that returns when profiling
45   // callbacks have been run.
46   void setProfilingFuture(c10::intrusive_ptr<JitFuture> profilingFuture);
47 
48   // create a proxy on this RRef, which can be used to launch RPC on the owner
49   // of this RRef to run functions on the object referenced by this RRef.
50   py::object createRRefProxy(
51       const RRefProxyType& mode,
52       float timeoutSeconds = rpc::kUnsetRpcTimeout) const;
53 
54   // get the type of the data object referenced by this RRef. Timeout argument
55   // is only used in the first invocation of this function as an argument to the
56   // RPC to the owner node of the RRef.
57   py::object getRRefType(
58       float timeout = rpc::kUnsetRpcTimeout,
59       bool blocking = true);
60 
61   // Run the backward pass with the RRef as the root.
62   void backward(int64_t autogradContextId, bool retainGraph);
63 
64   // Helper static function to run backward on a given rref.
65   static void backward(
66       int64_t autogradContextId,
67       bool retainGraph,
68       const c10::intrusive_ptr<RRef>& rref);
69 
70   // Specialization of backward if the rref is an OwnerRRef.
71   static void backwardOwnerRRef(
72       int64_t autogradContextId,
73       bool retainGraph,
74       IValue value);
75 
76  private:
77   c10::intrusive_ptr<RRef> rref_;
78   std::optional<c10::intrusive_ptr<JitFuture>> profilingFuture_;
79   std::optional<py::object> type_;
80 };
81 
82 } // namespace rpc
83 } // namespace distributed
84 } // namespace torch
85