1 #pragma once 2 3 #include <torch/csrc/distributed/rpc/rref_impl.h> 4 #include <torch/csrc/python_headers.h> 5 #include <torch/csrc/utils/pybind.h> 6 7 namespace torch { 8 namespace distributed { 9 namespace rpc { 10 11 enum RRefProxyType { RPC_SYNC, RPC_ASYNC, REMOTE }; 12 13 // Python wrapper of an RRef shared_ptr that supports Python 14 // pickle and unpickle. 15 class PYBIND11_EXPORT PyRRef { 16 public: 17 // The first ctor can only be called while holding GIL. See its implementation 18 // for more explanations. 19 explicit PyRRef(const py::object& value, const py::object& type_hint); 20 explicit PyRRef(c10::intrusive_ptr<RRef> rref); 21 PyRRef(const PyRRef&) = default; 22 ~PyRRef(); 23 24 bool isOwner() const; 25 bool confirmedByOwner() const; 26 WorkerInfo owner() const; 27 std::string ownerName() const; 28 py::object toHere( 29 const float timeoutSeconds = 30 torch::distributed::rpc::kUnsetRpcTimeout) const; 31 py::object localValue() const; 32 std::string str() const; 33 py::tuple pickle() const; 34 static PyRRef unpickle(const py::tuple& t); 35 c10::IValue toIValue() const; 36 // Future that is associated with the creation of this RRef on the remote end. 37 // This is only used to get the future corresponding to the rref for profiling 38 // use cases. 39 c10::intrusive_ptr<JitFuture> getFuture() const; 40 // Keeps track of the future responsible for profiling owner creation 41 // acknowledgement 42 c10::intrusive_ptr<JitFuture> getProfilingFuture() const; 43 // Sets the future responsible for profiling owner creation acknowledgement. 44 // This future is set from python to be a future that returns when profiling 45 // callbacks have been run. 46 void setProfilingFuture(c10::intrusive_ptr<JitFuture> profilingFuture); 47 48 // create a proxy on this RRef, which can be used to launch RPC on the owner 49 // of this RRef to run functions on the object referenced by this RRef. 50 py::object createRRefProxy( 51 const RRefProxyType& mode, 52 float timeoutSeconds = rpc::kUnsetRpcTimeout) const; 53 54 // get the type of the data object referenced by this RRef. Timeout argument 55 // is only used in the first invocation of this function as an argument to the 56 // RPC to the owner node of the RRef. 57 py::object getRRefType( 58 float timeout = rpc::kUnsetRpcTimeout, 59 bool blocking = true); 60 61 // Run the backward pass with the RRef as the root. 62 void backward(int64_t autogradContextId, bool retainGraph); 63 64 // Helper static function to run backward on a given rref. 65 static void backward( 66 int64_t autogradContextId, 67 bool retainGraph, 68 const c10::intrusive_ptr<RRef>& rref); 69 70 // Specialization of backward if the rref is an OwnerRRef. 71 static void backwardOwnerRRef( 72 int64_t autogradContextId, 73 bool retainGraph, 74 IValue value); 75 76 private: 77 c10::intrusive_ptr<RRef> rref_; 78 std::optional<c10::intrusive_ptr<JitFuture>> profilingFuture_; 79 std::optional<py::object> type_; 80 }; 81 82 } // namespace rpc 83 } // namespace distributed 84 } // namespace torch 85