xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/torchscript_functions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/ThreadLocalState.h>
2 #include <fmt/format.h>
3 #include <torch/csrc/autograd/record_function_ops.h>
4 #include <torch/csrc/distributed/autograd/utils.h>
5 #include <torch/csrc/distributed/rpc/message.h>
6 #include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
7 #include <torch/csrc/distributed/rpc/rpc_agent.h>
8 #include <torch/csrc/distributed/rpc/rref_proto.h>
9 #include <torch/csrc/distributed/rpc/script_call.h>
10 #include <torch/csrc/distributed/rpc/torchscript_functions.h>
11 #include <torch/csrc/distributed/rpc/utils.h>
12 
13 namespace torch::distributed::rpc {
14 
rpcTorchscript(const std::string & dstWorkerName,const c10::QualifiedName & qualifiedName,const c10::FunctionSchema & functionSchema,std::vector<c10::IValue> & stack,const float rpcTimeoutSeconds,const bool isAsyncExecution)15 c10::intrusive_ptr<JitFuture> rpcTorchscript(
16     const std::string& dstWorkerName,
17     const c10::QualifiedName& qualifiedName,
18     const c10::FunctionSchema& functionSchema,
19     std::vector<c10::IValue>& stack,
20     const float rpcTimeoutSeconds,
21     const bool isAsyncExecution) {
22   c10::intrusive_ptr<torch::autograd::profiler::PythonRecordFunction> record;
23   auto shouldProfile = torch::autograd::profiler::profilerEnabled() &&
24       !torch::distributed::rpc::RemoteProfilerManager::getInstance()
25            .isCurrentKeySet();
26   if (shouldProfile) {
27     auto rpcAsyncJitKey = fmt::format(
28         "rpc_async_jit#{}({} -> {})",
29         qualifiedName
30             .qualifiedName(), /* name of torchscript function being run */
31         RpcAgent::getCurrentRpcAgent()->getWorkerInfo().name_,
32         dstWorkerName);
33     record =
34         torch::autograd::profiler::record_function_enter_new(rpcAsyncJitKey);
35     auto& remoteProfilerManager =
36         torch::distributed::rpc::RemoteProfilerManager::getInstance();
37     remoteProfilerManager.setCurrentKey(rpcAsyncJitKey);
38   }
39   auto scriptCall = std::make_unique<ScriptCall>(
40       qualifiedName, std::move(stack), isAsyncExecution);
41   auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent();
42   auto jitFuture = autograd::sendMessageWithAutograd(
43       *rpcAgentPtr,
44       rpcAgentPtr->getWorkerInfo(dstWorkerName),
45       std::move(*scriptCall).toMessage(),
46       true /*forceGradRecording*/,
47       rpcTimeoutSeconds);
48 
49   // Get function return type to construct JitFuture.
50   auto returns = functionSchema.returns();
51   // Script call only allows single IValue returned.
52   TORCH_INTERNAL_ASSERT(
53       returns.size() == 1,
54       "Return value of an annotated torchScript function should be a single "
55       "IValue.",
56       returns.size());
57   auto returnType = returns.at(0).type();
58 
59   // Create a JIT future and pass it to futMessage's callback to set state
60   // of the JIT future.
61   auto futPtr = jitFuture->createInstance(returnType);
62   jitFuture->addCallback(at::wrapPropagateTLSState([futPtr](JitFuture& future) {
63     if (future.hasError()) {
64       futPtr->setError(future.exception_ptr());
65     } else {
66       futPtr->markCompleted(
67           deserializeRespToIValue(
68               *future.constValue().toCustomClass<Message>()),
69           future.storages());
70     }
71   }));
72   if (shouldProfile) {
73     auto profiledFutPtr =
74         torch::autograd::profiler::_call_end_callbacks_on_fut_new(
75             record, futPtr);
76     return profiledFutPtr;
77   }
78   return futPtr;
79 }
80 
remoteTorchscript(const std::string & dstWorkerName,const c10::QualifiedName & qualifiedName,const c10::FunctionSchema & functionSchema,std::vector<c10::IValue> & stack,const float rpcTimeoutSeconds,const bool isAsyncExecution)81 c10::intrusive_ptr<RRef> remoteTorchscript(
82     const std::string& dstWorkerName,
83     const c10::QualifiedName& qualifiedName,
84     const c10::FunctionSchema& functionSchema,
85     std::vector<c10::IValue>& stack,
86     const float rpcTimeoutSeconds,
87     const bool isAsyncExecution) {
88   auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent();
89   auto dstWorkerInfo = rpcAgentPtr->getWorkerInfo(dstWorkerName);
90   auto& ctx = RRefContext::getInstance();
91 
92   // Get function return type to construct UserRRef.
93   auto returns = functionSchema.returns();
94   // Script call only allows single IValue returned.
95   TORCH_INTERNAL_ASSERT(
96       returns.size() == 1,
97       "Return value of an annotated torchScript function should be a single "
98       "IValue.",
99       returns.size());
100   auto returnType = returns.at(0).type();
101 
102   if (ctx.getWorkerId() != dstWorkerInfo.id_) {
103     auto userRRefPtr = ctx.createUserRRef(dstWorkerInfo.id_, returnType);
104 
105     auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
106         qualifiedName,
107         std::move(stack),
108         userRRefPtr->rrefId(),
109         userRRefPtr->forkId(),
110         isAsyncExecution);
111 
112     auto jitFuture = torch::distributed::autograd::sendMessageWithAutograd(
113         *rpcAgentPtr,
114         dstWorkerInfo,
115         std::move(*scriptRemoteCall).toMessage(),
116         true /*forceGradRecording*/,
117         rpcTimeoutSeconds /* timeout */);
118 
119     userRRefPtr->registerOwnerCreationFuture(jitFuture);
120     ctx.addPendingUser(userRRefPtr->forkId(), userRRefPtr);
121     jitFuture->addCallback(at::wrapPropagateTLSState(
122         [forkId{userRRefPtr->forkId()}](JitFuture& future) {
123           callback::confirmPendingUser(future, forkId);
124         }));
125 
126     return userRRefPtr;
127   } else {
128     auto ownerRRefPtr = ctx.createOwnerRRef(returnType);
129     // prevent this owner RRef from being deleted due to other forks
130     ctx.addSelfAsFork(ownerRRefPtr);
131 
132     auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
133         qualifiedName,
134         std::move(stack),
135         ownerRRefPtr->rrefId(),
136         ownerRRefPtr->rrefId(),
137         isAsyncExecution);
138 
139     auto jitFuture = torch::distributed::autograd::sendMessageWithAutograd(
140         *rpcAgentPtr,
141         dstWorkerInfo,
142         std::move(*scriptRemoteCall).toMessage(),
143         true /*forceGradRecording*/,
144         rpcTimeoutSeconds /* timeout */);
145 
146     ownerRRefPtr->registerOwnerCreationFuture(jitFuture);
147     jitFuture->addCallback(at::wrapPropagateTLSState(
148         [ownerRRefId = ownerRRefPtr->rrefId()](JitFuture& future) {
149           callback::finishCreatingOwnerRRef(future, ownerRRefId);
150         }));
151     return ownerRRefPtr;
152   }
153 }
154 
155 } // namespace torch::distributed::rpc
156