#pragma once

#include <c10/core/Device.h>
#include <c10/core/Event.h>
#include <c10/core/Stream.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/rpc/rpc_command_base.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/utils/byte_order.h>

namespace torch {
namespace distributed {
namespace rpc {

// Parses the error message carried by a (failed) JitFuture and returns the
// RPCErrorType inferred from that message.
TORCH_API RPCErrorType getRPCErrorType(const JitFuture& jitFuture);

// Creates a full error string from an error description plus its error
// type, in the format that getRPCErrorType() above knows how to parse back.
TORCH_API std::string makeRPCError(
    const std::string& rpcErrorStr,
    RPCErrorType errorType);

// Given an RPC message received as a request over the wire, deserialize it
// into the appropriate 'RpcCommandBase' subtype.
TORCH_API std::unique_ptr<RpcCommandBase> deserializeRequest(
    const Message& request);

// Given an RPC message received as a response over the wire, deserialize it
// into the appropriate 'RpcCommandBase' type. If the response is of
// FORWARD_AUTOGRAD_RESP type, unwrap it, attach recvBackward() functions
// to received tensors, and set wrappedMsgType to its wrapped message type.
// wrappedMsgType is an out-parameter: it is only meaningful for wrapped
// (autograd-carrying) responses.
TORCH_API std::unique_ptr<RpcCommandBase> deserializeResponse(
    const Message& response,
    MessageType& wrappedMsgType);

// Given an already-deserialized response 'rpc', extract a valid IValue if
// the message is a script RPC result; otherwise produce a dummy None IValue
// that will never be used. This step also attaches recv RPC backward
// functions when needed.
// NOTE(review): the lowercase "to" in this name is inconsistent with
// deserializeRespToIValue below, but it is part of the public-ish interface
// and is kept as-is.
// Not TORCH_API: internal helper for deserializeRespToIValue.
IValue deserializeResptoIValueInternal(
    RpcCommandBase& rpc,
    MessageType messageType);

// Convenience entry point: deserializes a wire Message all the way down to
// the IValue result (see deserializeResptoIValueInternal above).
TORCH_API IValue deserializeRespToIValue(const Message& message);

// Note: the wire format below is subject to change and intended for RPCs
// only. For saving persistently to disk, use torch::save().
// Serializes a payload blob plus its tensors into a single wire string.
// Inverse of wireDeserialize below.
TORCH_API std::string wireSerialize(
    const std::vector<char>& payload,
    const std::vector<at::Tensor>& tensors);

// Parses a buffer produced by wireSerialize back into (payload, tensors).
// 'data'/'data_size' describe the raw received bytes; they are not owned.
TORCH_API std::pair<std::vector<char>, std::vector<at::Tensor>> wireDeserialize(
    const void* data,
    size_t data_size);

// We use vector<char> as the type of blobs because it's what rpc::Message uses
// for its payload, even though it has the disadvantage that it cannot be
// allocated with uninitialized memory: it is always zeroed out.

// Some Tensors are effectively views of larger Tensors, where only a small
// subset of the Storage data is referenced. This normally is good and avoids
// copies when kept locally, but if we naively push the whole Storage over the
// wire, we'll end up with excess network traffic. This change clones tensors
// if we'd save at least half the data, and over a minimum hurdle.
TORCH_API c10::List<at::Tensor> cloneSparseTensors(
    const std::vector<at::Tensor>& tensors);

// Combines an original payload and an additional (wrapped) payload into the
// original payload, in place. Used to generate the overall payload for a
// wrapped RPC.
TORCH_API void writeWrappedPayload(
    std::vector<char>& originalPayload,
    std::vector<char>& additionalPayload);

// Reads the additional, wrapped payload from a wrapped RPC off of the input
// payload and returns it as IValues. After this call, 'payload' is mutated
// to contain only the payload of the original, un-wrapped RPC.
TORCH_API std::vector<at::IValue> readWrappedPayload(
    std::vector<char>& payload,
    const rpc::Message& message);

// Takes the lists of events produced by the autograd profiler and populates
// them into 'profiledEvents' so they can be carried over RPC.
// Flattens profiler 'eventLists' (one list per thread/run, as produced by the
// legacy autograd profiler) into 'profiledEvents', honoring 'profilerConfig'.
TORCH_API void populateRemoteProfiledEvents(
    std::vector<torch::autograd::profiler::LegacyEvent>& profiledEvents,
    const torch::autograd::profiler::ProfilerConfig& profilerConfig,
    const std::vector<std::vector<torch::autograd::profiler::LegacyEvent>>&
        eventLists);

} // namespace rpc
} // namespace distributed
} // namespace torch