xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Device.h>
4 #include <c10/core/Event.h>
5 #include <c10/core/Stream.h>
6 #include <torch/csrc/autograd/profiler.h>
7 #include <torch/csrc/distributed/rpc/rpc_command_base.h>
8 #include <torch/csrc/jit/serialization/pickle.h>
9 #include <torch/csrc/utils/byte_order.h>
10 
11 namespace torch {
12 namespace distributed {
13 namespace rpc {
14 
15 // Parse error message and return RPCErrorType based on the message.
16 TORCH_API RPCErrorType getRPCErrorType(const JitFuture& jitFuture);
17 // Create an error string given the error description and error type
18 TORCH_API std::string makeRPCError(
19     const std::string& rpcErrorStr,
20     RPCErrorType errorType);
21 
22 // Given an RPC message received as a request over the wire, deserialize it into
23 // the appropriate 'RpcCommandBase' type.
24 TORCH_API std::unique_ptr<RpcCommandBase> deserializeRequest(
25     const Message& request);
26 
27 // Given an RPC message received as a response over the wire, deserialize it
28 // into the appropriate 'RpcCommandBase' type, if the response is
29 // FORWARD_AUTOGRAD_RESP type, unwrap it, attach recvBackward() functions
30 // to received tensors and set the wrappedMsgType to its wrapped message type.
31 TORCH_API std::unique_ptr<RpcCommandBase> deserializeResponse(
32     const Message& response,
33     MessageType& wrappedMsgType);
34 
35 // Given an RPC message received as a response over the wire, deserialize it
36 // into the valid IValue if the message is for a script rpc result,
37 // otherwise deserialize it into dummy none ivalue that will never be used.
38 // In this deserialization, we also attach recv rpc backward functions if
39 // needed.
40 IValue deserializeResptoIValueInternal(
41     RpcCommandBase& rpc,
42     MessageType messageType);
43 TORCH_API IValue deserializeRespToIValue(const Message& message);
44 
45 // Note: format is subject to change and intended for RPCs.
46 // For saving persistently to disk, use torch::save().
47 TORCH_API std::string wireSerialize(
48     const std::vector<char>& payload,
49     const std::vector<at::Tensor>& tensors);
50 
51 TORCH_API std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize(
52     const void* data,
53     size_t data_size);
54 
55 // We use vector<char> as the type of blobs because it's what rpc::Message uses
56 // for its payload, even though it has the disadvantage that it cannot be
57 // allocated with uninitialized memory: it is always zeroed out.
58 
59 // Some Tensors are effectively views of larger Tensors, where only a small
60 // subset of the Storage data is referenced. This normally is good and avoids
61 // copies when kept locally, but if we naively push the whole Storage over the
62 // wire, we'll end up with excess network traffic. This change clones tensors if
63 // we'd save at least half the data, and over a minimum hurdle.
64 TORCH_API c10::List<at::Tensor> cloneSparseTensors(
65     const std::vector<at::Tensor>& tensors);
66 
67 // Combines an original payload and wrapped payload into the original payload.
68 // Used to generate the overall payload for the wrapped RPC.
69 TORCH_API void writeWrappedPayload(
70     std::vector<char>& originalPayload,
71     std::vector<char>& additionalPayload);
72 
73 // Reads the additional, wrapped payload from a wrapped RPC off of the input
74 // payload. After this, payload will contain the payload of the original,
75 // un-wrapped RPC.
76 TORCH_API std::vector<at::IValue> readWrappedPayload(
77     std::vector<char>& payload,
78     const rpc::Message& message);
79 
80 // Takes a list of events from autograd profiler and populates them into
81 // profiledEvents to be carried over RPC.
82 TORCH_API void populateRemoteProfiledEvents(
83     std::vector<torch::autograd::profiler::LegacyEvent>& profiledEvents,
84     const torch::autograd::profiler::ProfilerConfig& profilerConfig,
85     const std::vector<std::vector<torch::autograd::profiler::LegacyEvent>>&
86         eventLists);
87 
88 } // namespace rpc
89 } // namespace distributed
90 } // namespace torch
91