xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/python_functions.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/distributed/rpc/py_rref.h>
4 #include <torch/csrc/distributed/rpc/rpc_agent.h>
5 #include <torch/csrc/jit/python/pybind_utils.h>
6 #include <torch/csrc/utils/pybind.h>
7 
8 namespace torch::distributed::rpc {
9 
10 // Converts an internal ivalue::Future of Message into a user-facing
11 // ivalue::Future of py::object type by creating a new ivalue::Future and
12 // calling its markCompleted from a callback on the given ivalue::Future.
13 // If hasValue is true (the default), the Message is converted into a
14 // py::object and then wrapped in an IValue. If hasValue is false, this
15 // ivalue::Future is only used for signaling and launching callbacks. In that
16 // case, the message is discarded and the ivalue::Future is set using an empty
17 // IValue, or the given FutureError if there is an error.
18 c10::intrusive_ptr<JitFuture> toPyJitFuture(
19     const c10::intrusive_ptr<JitFuture>& messageJitFuture,
20     bool hasValue = true);
21 
// Declaration only; the definition is not part of this header.
// Takes a WorkerInfo `dst`, an operator name `opName`, the Python positional
// and keyword arguments (`args`, `kwargs`), and a float timeout
// `rpcTimeoutSeconds` (seconds, per its name). Returns an intrusive_ptr to a
// JitFuture.
// NOTE(review): presumably issues an RPC that runs the builtin operator
// `opName` on worker `dst` and completes the returned future with the result
// -- confirm against the definition.
22 c10::intrusive_ptr<JitFuture> pyRpcBuiltin(
23     const WorkerInfo& dst,
24     const std::string& opName,
25     const py::args& args,
26     const py::kwargs& kwargs,
27     const float rpcTimeoutSeconds);
28 
// Declaration only; the definition is not part of this header.
// Takes a WorkerInfo `dst`, a string `pickledPythonUDF`, a vector of tensors
// `tensors`, a float timeout `rpcTimeoutSeconds` (seconds, per its name), and
// the flag `isAsyncExecution`. Returns an intrusive_ptr to a JitFuture.
// NOTE(review): `pickledPythonUDF` and `tensors` are non-const lvalue
// references, so the callee is permitted to modify (e.g. move from) the
// caller's objects -- confirm against the definition before reusing them
// after the call.
// NOTE(review): presumably sends a pickled Python user-defined function to
// `dst` for execution -- confirm against the definition.
29 c10::intrusive_ptr<JitFuture> pyRpcPythonUdf(
30     const WorkerInfo& dst,
31     std::string& pickledPythonUDF,
32     std::vector<torch::Tensor>& tensors,
33     const float rpcTimeoutSeconds,
34     const bool isAsyncExecution);
35 
// Declaration only; the definition is not part of this header.
// Unlike pyRpcBuiltin and pyRpcPythonUdf, the destination is given as a
// worker name string (`dstWorkerName`) rather than a WorkerInfo.
// `qualifiedNameStr` is a string name, `argsTuple` / `kwargsDict` are a
// py::tuple and py::dict of arguments, `rpcTimeoutSeconds` is a float timeout
// (seconds, per its name), and `isAsyncExecution` is a boolean flag.
// Returns an intrusive_ptr to a JitFuture.
// NOTE(review): presumably `qualifiedNameStr` identifies the TorchScript
// function to run on the destination worker -- confirm against the
// definition.
36 c10::intrusive_ptr<JitFuture> pyRpcTorchscript(
37     const std::string& dstWorkerName,
38     const std::string& qualifiedNameStr,
39     const py::tuple& argsTuple,
40     const py::dict& kwargsDict,
41     const float rpcTimeoutSeconds,
42     const bool isAsyncExecution);
43 
// Declaration only; the definition is not part of this header.
// Takes the same kinds of arguments as pyRpcBuiltin but returns a PyRRef by
// value instead of a JitFuture. Note the parameter order differs from
// pyRpcBuiltin: `rpcTimeoutSeconds` comes before `args` and `kwargs` here.
// NOTE(review): presumably runs the builtin operator `opName` on worker `dst`
// and returns a remote reference to the result -- confirm against the
// definition.
44 PyRRef pyRemoteBuiltin(
45     const WorkerInfo& dst,
46     const std::string& opName,
47     const float rpcTimeoutSeconds,
48     const py::args& args,
49     const py::kwargs& kwargs);
50 
// Declaration only; the definition is not part of this header.
// Has the same parameter list as pyRpcPythonUdf but returns a PyRRef by value
// instead of a JitFuture.
// NOTE(review): `pickledPythonUDF` and `tensors` are non-const lvalue
// references, so the callee is permitted to modify (e.g. move from) the
// caller's objects -- confirm against the definition before reusing them
// after the call.
// NOTE(review): presumably runs a pickled Python user-defined function on
// `dst` and returns a remote reference to its result -- confirm against the
// definition.
51 PyRRef pyRemotePythonUdf(
52     const WorkerInfo& dst,
53     std::string& pickledPythonUDF,
54     std::vector<torch::Tensor>& tensors,
55     const float rpcTimeoutSeconds,
56     const bool isAsyncExecution);
57 
// Declaration only; the definition is not part of this header.
// Returns a PyRRef by value. The destination is given as a worker name string
// (`dstWorkerName`), and `qualifiedNameStr` is a string name.
// Differences from pyRpcTorchscript visible in the signatures: the timeout and
// `isAsyncExecution` come before the arguments, and the arguments are passed
// as py::args / py::kwargs rather than py::tuple / py::dict.
// NOTE(review): presumably runs the TorchScript function named by
// `qualifiedNameStr` on the destination worker and returns a remote reference
// to its result -- confirm against the definition.
58 PyRRef pyRemoteTorchscript(
59     const std::string& dstWorkerName,
60     const std::string& qualifiedNameStr,
61     const float rpcTimeoutSeconds,
62     const bool isAsyncExecution,
63     const py::args& args,
64     const py::kwargs& kwargs);
65 
66 } // namespace torch::distributed::rpc
67