1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_ 18 19 #include <memory> 20 #include <vector> 21 22 #include "tensorflow/core/common_runtime/eager/eager_executor.h" 23 #include "tensorflow/core/common_runtime/eager/eager_operation.h" 24 #include "tensorflow/core/common_runtime/eager/tensor_handle.h" 25 #include "tensorflow/core/framework/cancellation.h" 26 #include "tensorflow/core/framework/tensor.h" 27 28 namespace tensorflow { 29 namespace eager { 30 31 // This node supports copying a tensor in the following way: 32 // - Remote -> Local: 33 // We don't block on the remote _Send op and start executing the local 34 // _Recv immediately after issuing the remote _Send. The local _Recv 35 // kernel (or rather the special _Recv handling in KernelAndDeviceOp::Run) 36 // blocks until the tensor is received. If the remote _Send (or some op 37 // before it) fails, the local callback we give to EnqueueAsync will run 38 // and call CancellationManager.StartCancel(). The blocked local _Recv will 39 // get this notification and return with a cancelled error. 40 // 41 // - Local -> Remote: 42 // The local _Send op is synchronous and non-blocking, thus it should complete 43 // quickly. We issue remote _Recv RPC only after local _Send completes 44 // successfully. At this point, the tensor to be sent is in the local 45 // Rendezvous, hence, remote _Recv op will not deadlock waiting for the tensor 46 // to appear. 47 // When ctx->UseSendTensorRPC() is true, we use EagerService::Enqueue 48 // SendTensor instead of _Send/_Recv. 49 // 50 // - Remote -> Remote: 51 // We could issue both remote ops asynchronously, but if remote _Send (or some 52 // op before it) fails, we don't have a good way of cancelling the remote 53 // _Recv. The remote _Recv will deadlock in this case. The current approach 54 // to deal with this issue is to wait for remote _Send to complete before 55 // issuing remote _Recv RPC. Another option is to close the whole streaming 56 // RPC that contains the deadlocked remote _Recv. This would not unblock the 57 // deadlocked RPC on the remote machine without some extra code. Luckily, the 58 // remote -> remote case seems to be fairly rare at this point. So, the 59 // current partially synchronous approach seems fine. 60 // 61 // To copy a tensor within a host, please use copy_to_device_node instead. 62 class RemoteCopyNode : public AsyncEagerNode { 63 public: 64 RemoteCopyNode(EagerContext* ctx, EagerExecutor* executor, TensorHandle* src, 65 TensorHandle* dst, Device* recv_device, uint64 recv_op_id); 66 67 ~RemoteCopyNode() override; 68 69 Status Prepare() override; 70 71 void RunAsync(StatusCallback done) override; 72 73 void Abort(Status status) override; 74 DebugString()75 string DebugString() const override { 76 string out = "[RemoteCopyNode]"; 77 strings::StrAppend(&out, " send_device: ", send_device_->name()); 78 strings::StrAppend(&out, ", recv_device: ", recv_device_->name()); 79 strings::StrAppend(&out, ", send_tensor: ", src_->DebugString()); 80 strings::StrAppend( 81 &out, ", recv_tensor: ", captured_state_->dst()->DebugString()); 82 return out; 83 } 84 85 private: 86 // Runs the _Send operation locally or remotely. 87 // StartSend() makes sure that captured_state_->send_status_ is set to the 88 // final _Send status after captured_state->send_done_.WaitForNotification() 89 // returns. 90 void StartSend(); 91 92 // Synchronously runs local send `op` and returns its status. 93 Status RunLocalSend(EagerOperation* op); 94 95 // Runs the _Recv operation locally or remotely. 96 // An error return value indicates that _Recv did not run successfully. It 97 // does not indicate that _Send op has completed since StartRecv could have 98 // encountered an error before waiting for _Send's completion. 99 // An OK return value does NOT necessarily indicate that _Recv has completed 100 // successfully (it does now, but won't when streaming RPCs are turned on). 101 // StartRecv() makes sure that dst_ tensor handle is handled correctly 102 // (potentially after this methods returns); a tensor is set in the local 103 // case, a remote shape is set in the remote case, the dst_ handle is 104 // poisoned in either case if there is an error. 105 void StartRecv(StatusCallback done); 106 107 // Synchronously runs local receive `op` and returns its status. 108 // Does not wait for the send to complete before running receive. 109 Status RunLocalRecv(EagerOperation* op, std::vector<Tensor>* outputs); 110 111 // Waits for send to complete, then issues remote receive `op` and 112 // returns its status. 113 void RunRemoteRecv(EagerOperation* op, StatusCallback done); 114 115 // When !ctx->UseSendTensorRPC(), then tensors are shipped between remote 116 // devices by the receiver invoking the WorkerService.RecvTensor RPC *on the 117 // sender* (Rendezvous::RecvAsync() invoked by the _Recv kernel). 118 // 119 // However, in some configurations the node that has the tensor to be copied 120 // isn't running a server (WorkerService RPC interface). For such cases, 121 // this function enables sending tensors using the EagerService.Enqueue 122 // SendTensor RPC *on the receiver*. 123 void StartRemoteSendTensor(StatusCallback done); 124 125 // Send a local packed TensorHandle to a remote device. 126 void StartSendPackedHandle(StatusCallback done); 127 128 // State that is captured by Send and/or Recv callbacks (depending on which 129 // one(s) is remote) and outlives this node in the case of remote->remote 130 // copy. 131 class CapturedSharedState { 132 public: CapturedSharedState(TensorHandle * d)133 explicit CapturedSharedState(TensorHandle* d) : dst_(d) { dst_->Ref(); } ~CapturedSharedState()134 ~CapturedSharedState() { dst_->Unref(); } 135 SetSendStatus(Status status)136 void SetSendStatus(Status status) { 137 send_status_.Update(status); 138 send_done_.Notify(); 139 } 140 GetSendStatus()141 Status GetSendStatus() { 142 send_done_.WaitForNotification(); 143 return send_status_; 144 } 145 146 // src_shape_ is not thread-safe. It should only be set in one thread. SetSrcShape(const TensorShape & shape)147 void SetSrcShape(const TensorShape& shape) { src_shape_ = shape; } 148 GetSrcShape()149 const TensorShape& GetSrcShape() { return src_shape_; } 150 dst()151 TensorHandle* dst() { return dst_; } recv_cancellation()152 CancellationManager* recv_cancellation() { return &recv_cancellation_; } 153 154 private: 155 TensorHandle* const dst_; 156 CancellationManager recv_cancellation_; 157 // send_status_ is safe to read only after send_done_.WaitForNotification() 158 // has returned. 159 Status send_status_; 160 Notification send_done_; 161 TensorShape src_shape_; 162 }; 163 164 TensorHandle* const src_; 165 EagerContext* const ctx_; 166 EagerExecutor* const executor_; 167 Device* const send_device_; 168 Device* const recv_device_; 169 const string wire_id_; 170 const uint64 recv_op_id_; 171 172 std::shared_ptr<CapturedSharedState> captured_state_; 173 bool started_; 174 }; 175 176 } // namespace eager 177 } // namespace tensorflow 178 179 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_EAGER_REMOTE_COPY_NODE_H_ 180