1 #pragma once 2 3 #include <torch/csrc/distributed/rpc/py_rref.h> 4 #include <torch/csrc/distributed/rpc/rpc_agent.h> 5 #include <torch/csrc/jit/python/pybind_utils.h> 6 #include <torch/csrc/utils/pybind.h> 7 8 namespace torch::distributed::rpc { 9 10 // Converts an internal ivalue::Future of Message into a user-facing 11 // ivalue::Future of py::object type by creating a new ivalue::Future and call 12 // its markCompleted as a callback in the given ivalue::Future. 13 // If hasValue is true, the Message will be converted into a py::object and then 14 // wrap it with an IValue. If hasValue is false, this ivalue::Future is only 15 // used for signaling and launching callbacks. In this case, the message will be 16 // discarded and then set the ivalue::Future using an empty IValue or the given 17 // FutureError if there is an error. 18 c10::intrusive_ptr<JitFuture> toPyJitFuture( 19 const c10::intrusive_ptr<JitFuture>& messageJitFuture, 20 bool hasValue = true); 21 22 c10::intrusive_ptr<JitFuture> pyRpcBuiltin( 23 const WorkerInfo& dst, 24 const std::string& opName, 25 const py::args& args, 26 const py::kwargs& kwargs, 27 const float rpcTimeoutSeconds); 28 29 c10::intrusive_ptr<JitFuture> pyRpcPythonUdf( 30 const WorkerInfo& dst, 31 std::string& pickledPythonUDF, 32 std::vector<torch::Tensor>& tensors, 33 const float rpcTimeoutSeconds, 34 const bool isAsyncExecution); 35 36 c10::intrusive_ptr<JitFuture> pyRpcTorchscript( 37 const std::string& dstWorkerName, 38 const std::string& qualifiedNameStr, 39 const py::tuple& argsTuple, 40 const py::dict& kwargsDict, 41 const float rpcTimeoutSeconds, 42 const bool isAsyncExecution); 43 44 PyRRef pyRemoteBuiltin( 45 const WorkerInfo& dst, 46 const std::string& opName, 47 const float rpcTimeoutSeconds, 48 const py::args& args, 49 const py::kwargs& kwargs); 50 51 PyRRef pyRemotePythonUdf( 52 const WorkerInfo& dst, 53 std::string& pickledPythonUDF, 54 std::vector<torch::Tensor>& tensors, 55 const float rpcTimeoutSeconds, 56 const bool isAsyncExecution); 57 58 PyRRef pyRemoteTorchscript( 59 const std::string& dstWorkerName, 60 const std::string& qualifiedNameStr, 61 const float rpcTimeoutSeconds, 62 const bool isAsyncExecution, 63 const py::args& args, 64 const py::kwargs& kwargs); 65 66 } // namespace torch::distributed::rpc 67