1 #include <ATen/ThreadLocalState.h>
2 #include <fmt/format.h>
3 #include <torch/csrc/autograd/record_function_ops.h>
4 #include <torch/csrc/distributed/autograd/utils.h>
5 #include <torch/csrc/distributed/rpc/message.h>
6 #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
7 #include <torch/csrc/distributed/rpc/rpc_agent.h>
8 #include <torch/csrc/distributed/rpc/rref_proto.h>
9 #include <torch/csrc/distributed/rpc/script_call.h>
10 #include <torch/csrc/distributed/rpc/torchscript_functions.h>
11 #include <torch/csrc/distributed/rpc/utils.h>
12
13 namespace torch::distributed::rpc {
14
// Sends an RPC that runs the TorchScript function `qualifiedName` on worker
// `dstWorkerName` and returns a JIT future for its single return value.
//
// Params:
//   dstWorkerName     - name of the destination worker.
//   qualifiedName     - fully qualified name of the TorchScript function.
//   functionSchema    - schema of that function; must declare exactly one
//                       return value (its type types the returned future).
//   stack             - call arguments; moved from (consumed) by this call.
//   rpcTimeoutSeconds - timeout forwarded to the send.
//   isAsyncExecution  - forwarded into the ScriptCall.
//
// Returns a future that is completed with the deserialized response (or set
// to the error of the underlying RPC future). When profiling is active the
// returned future is the one produced by `_call_end_callbacks_on_fut_new`.
c10::intrusive_ptr<JitFuture> rpcTorchscript(
    const std::string& dstWorkerName,
    const c10::QualifiedName& qualifiedName,
    const c10::FunctionSchema& functionSchema,
    std::vector<c10::IValue>& stack,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution) {
  // Validate the schema and resolve the destination before any side effects
  // (profiling key, moving out of `stack`, sending), so that e.g. an unknown
  // worker name fails without leaving thread-local profiling state behind.
  const auto& returns = functionSchema.returns();
  // Script call only allows single IValue returned.
  TORCH_INTERNAL_ASSERT(
      returns.size() == 1,
      "Return value of an annotated torchScript function should be a single "
      "IValue, but got ",
      returns.size(),
      " return values.");
  auto returnType = returns.at(0).type();

  auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent();
  const auto& dstWorkerInfo = rpcAgentPtr->getWorkerInfo(dstWorkerName);

  auto& remoteProfilerManager =
      torch::distributed::rpc::RemoteProfilerManager::getInstance();
  // Only start a new profiling scope if the profiler is on and no enclosing
  // RPC has already claimed the thread-local profiling key.
  const bool shouldProfile = torch::autograd::profiler::profilerEnabled() &&
      !remoteProfilerManager.isCurrentKeySet();
  c10::intrusive_ptr<torch::autograd::profiler::PythonRecordFunction> record;
  if (shouldProfile) {
    auto rpcAsyncJitKey = fmt::format(
        "rpc_async_jit#{}({} -> {})",
        qualifiedName
            .qualifiedName(), /* name of torchscript function being run */
        rpcAgentPtr->getWorkerInfo().name_,
        dstWorkerName);
    record =
        torch::autograd::profiler::record_function_enter_new(rpcAsyncJitKey);
    remoteProfilerManager.setCurrentKey(std::move(rpcAsyncJitKey));
  }

  c10::intrusive_ptr<JitFuture> jitFuture;
  try {
    ScriptCall scriptCall(qualifiedName, std::move(stack), isAsyncExecution);
    jitFuture = autograd::sendMessageWithAutograd(
        *rpcAgentPtr,
        dstWorkerInfo,
        std::move(scriptCall).toMessage(),
        true /*forceGradRecording*/,
        rpcTimeoutSeconds);
  } catch (...) {
    // This function never clears the key on the success path; it presumably
    // relies on the send path to consume it (verify in
    // sendMessageWithAutograd). If serialization or sending throws, clear it
    // here so later RPCs on this thread are neither skipped nor attributed to
    // this call. The record function is released with `record`.
    if (shouldProfile) {
      remoteProfilerManager.unsetCurrentKey();
    }
    throw;
  }

  // Create a JIT future typed with the function's return type and complete it
  // from the RPC response: propagate errors, otherwise deserialize the
  // response Message into the result IValue.
  auto futPtr = jitFuture->createInstance(returnType);
  jitFuture->addCallback(at::wrapPropagateTLSState([futPtr](JitFuture& future) {
    if (future.hasError()) {
      futPtr->setError(future.exception_ptr());
    } else {
      futPtr->markCompleted(
          deserializeRespToIValue(
              *future.constValue().toCustomClass<Message>()),
          future.storages());
    }
  }));
  if (shouldProfile) {
    // Presumably ends `record` once `futPtr` completes.
    return torch::autograd::profiler::_call_end_callbacks_on_fut_new(
        record, futPtr);
  }
  return futPtr;
}
80
// Asks worker `dstWorkerName` to run the TorchScript function `qualifiedName`
// and returns an RRef to its single return value.
//
// Params:
//   dstWorkerName     - name of the destination (owner) worker.
//   qualifiedName     - fully qualified name of the TorchScript function.
//   functionSchema    - schema of that function; must declare exactly one
//                       return value (its type types the created RRef).
//   stack             - call arguments; moved from (consumed) by this call.
//   rpcTimeoutSeconds - timeout forwarded to the send.
//   isAsyncExecution  - forwarded into the ScriptRemoteCall.
//
// If the destination is another worker, a UserRRef is created, registered as
// a pending user, and confirmed via callback::confirmPendingUser when the RPC
// future completes. If the destination is this worker, an OwnerRRef is
// created, added as its own fork, and finished via
// callback::finishCreatingOwnerRRef when the RPC future completes.
c10::intrusive_ptr<RRef> remoteTorchscript(
    const std::string& dstWorkerName,
    const c10::QualifiedName& qualifiedName,
    const c10::FunctionSchema& functionSchema,
    std::vector<c10::IValue>& stack,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution) {
  auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent();
  auto dstWorkerInfo = rpcAgentPtr->getWorkerInfo(dstWorkerName);
  auto& ctx = RRefContext::getInstance();

  // Get function return type to construct the RRef.
  const auto& returns = functionSchema.returns();
  // Script call only allows single IValue returned.
  TORCH_INTERNAL_ASSERT(
      returns.size() == 1,
      "Return value of an annotated torchScript function should be a single "
      "IValue, but got ",
      returns.size(),
      " return values.");
  auto returnType = returns.at(0).type();

  // Builds a ScriptRemoteCall tagged with the given RRef/fork ids and sends it
  // to the destination. Moves out of `stack`, so it is invoked exactly once.
  auto sendScriptRemoteCall = [&](const auto& rrefId, const auto& forkId) {
    ScriptRemoteCall scriptRemoteCall(
        qualifiedName, std::move(stack), rrefId, forkId, isAsyncExecution);
    return torch::distributed::autograd::sendMessageWithAutograd(
        *rpcAgentPtr,
        dstWorkerInfo,
        std::move(scriptRemoteCall).toMessage(),
        true /*forceGradRecording*/,
        rpcTimeoutSeconds /* timeout */);
  };

  if (ctx.getWorkerId() != dstWorkerInfo.id_) {
    auto userRRefPtr = ctx.createUserRRef(dstWorkerInfo.id_, returnType);
    auto jitFuture =
        sendScriptRemoteCall(userRRefPtr->rrefId(), userRRefPtr->forkId());

    userRRefPtr->registerOwnerCreationFuture(jitFuture);
    ctx.addPendingUser(userRRefPtr->forkId(), userRRefPtr);
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [forkId{userRRefPtr->forkId()}](JitFuture& future) {
          callback::confirmPendingUser(future, forkId);
        }));
    return userRRefPtr;
  }

  auto ownerRRefPtr = ctx.createOwnerRRef(returnType);
  // prevent this owner RRef from being deleted due to other forks
  ctx.addSelfAsFork(ownerRRefPtr);

  // The owner RRef's id is passed in both the RRef-id and fork-id slots.
  auto jitFuture =
      sendScriptRemoteCall(ownerRRefPtr->rrefId(), ownerRRefPtr->rrefId());

  ownerRRefPtr->registerOwnerCreationFuture(jitFuture);
  jitFuture->addCallback(at::wrapPropagateTLSState(
      [ownerRRefId = ownerRRefPtr->rrefId()](JitFuture& future) {
        callback::finishCreatingOwnerRRef(future, ownerRRefId);
      }));
  return ownerRRefPtr;
}
154
155 } // namespace torch::distributed::rpc
156