1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_HOST_CALLBACK_H_ 17 #define TENSORFLOW_COMPILER_XLA_PJRT_HOST_CALLBACK_H_ 18 19 #include <atomic> 20 #include <functional> 21 #include <utility> 22 #include <vector> 23 24 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" 25 26 // The following provides an API for implementing host callbacks on top of 27 // PjRT's send/recv interface (see xla::SendCallback and xla::RecvCallback). 28 // While this is not the only way to implement host callbacks using send/recv, 29 // it is provided as an example implementation that encapsulates common 30 // mechanisms for host callbacks in a framework-agnostic manner. 31 32 namespace xla { 33 34 // A thread-safe queue for passing PjRtChunk objects for e.g. from Send ops to 35 // Recv ops. 36 class ThreadSafePjRtChunkQueue { 37 public: 38 // Push a PjRtChunk into the queue. Push(PjRtChunk chunk)39 void Push(PjRtChunk chunk) { 40 absl::MutexLock lock(&mu_); 41 queue_.push_back(std::move(chunk)); 42 } 43 44 // Pop a PjRtChunk from the queue. This method blocks if the queue is empty. 
Pop()45 PjRtChunk Pop() { 46 absl::MutexLock lock(&mu_); 47 auto cond = [this]() { 48 mu_.AssertHeld(); 49 return !queue_.empty(); 50 }; 51 mu_.Await(absl::Condition(&cond)); 52 auto chunk = std::move(queue_.front()); 53 queue_.pop_front(); 54 return chunk; 55 } 56 57 private: 58 absl::Mutex mu_; 59 std::deque<PjRtChunk> queue_ ABSL_GUARDED_BY(mu_); 60 }; 61 62 struct HostCallbackArgInfo { 63 // The channel_id associated with this value in HLO. 64 uint16_t channel_id; 65 // The host shape for thie value. 66 Shape shape; 67 }; 68 69 struct HostCallback { 70 // The metadata (e.g. channel_id, shape) for the operands and results. 71 std::vector<HostCallbackArgInfo> operands; 72 std::vector<HostCallbackArgInfo> results; 73 74 // The host callback function takes two pointer arrays, each element of which 75 // points to allocated host buffer according to corresponding operand or 76 // result's shape. The first is for the outputs and the second is for the 77 // inputs. The buffers are only guaranteed to be alive during the call. The 78 // callback can also return error status to indicate the entire execution 79 // should fail. 80 std::function<Status(void**, void**)> callback; 81 }; 82 83 // A helper class that maintains the send/recv states for a host callback. 
84 class HostCallbackContext { 85 public: HostCallbackContext(HostCallback host_callback,PjRtClient * client)86 HostCallbackContext(HostCallback host_callback, PjRtClient* client) 87 : HostCallbackContext(std::move(host_callback), 88 client->GetPjRtHostMemoryForDeviceManager()) {} 89 HostCallbackContext(HostCallback host_callback,PjRtHostMemoryForDeviceManager * host_memory_for_device_manager)90 HostCallbackContext( 91 HostCallback host_callback, 92 PjRtHostMemoryForDeviceManager* host_memory_for_device_manager) 93 : host_callback_(std::move(host_callback)), 94 host_memory_for_device_manager_(host_memory_for_device_manager), 95 args_(host_callback_.operands.size()), 96 result_channels_(host_callback_.results.size()), 97 ready_count_(args_.size()) { 98 CHECK(host_memory_for_device_manager_); 99 100 for (auto& channel : result_channels_) { 101 channel = std::make_unique<ThreadSafePjRtChunkQueue>(); 102 } 103 } 104 105 Status OnSend(int arg_num, const PjRtTransferMetadata& metadata, 106 PjRtChunk data); 107 108 void Receive(int res_num, const PjRtTransferMetadata& metadata, 109 CopyToDeviceStream& stream); 110 host_callback()111 const HostCallback& host_callback() const { return host_callback_; } 112 113 private: 114 HostCallback host_callback_; 115 PjRtHostMemoryForDeviceManager* host_memory_for_device_manager_ = nullptr; 116 std::vector<PjRtChunk> args_; 117 std::vector<std::unique_ptr<ThreadSafePjRtChunkQueue>> result_channels_; 118 std::atomic<int> ready_count_; 119 }; 120 121 // The execution states for host callbacks for all replicas. The states are kept 122 // as vectors of vectors. The outer vector corresponds to the execution 123 // replicas. The inner vector is a list of host callback states for a single 124 // execution replica. 
struct HostCallbackStates {
  // contexts[replica][i] owns the state for the i-th host callback of that
  // replica; the matching send/recv callbacks below point into it.
  std::vector<std::vector<std::unique_ptr<HostCallbackContext>>> contexts;
  std::vector<std::vector<SendCallback>> send_callbacks;
  std::vector<std::vector<RecvCallback>> recv_callbacks;
};

// Creates the execution context for the `host_callback` for one replica.
// Appends the send/recv callbacks wired to the returned context onto
// `send_callbacks` and `recv_callbacks` (hence the reference parameters).
// `host_memory_for_device_manager` is borrowed, not owned.
std::unique_ptr<HostCallbackContext>
CreateHostCallbackStateAndAppendSendRecvCallbacks(
    HostCallback host_callback,
    PjRtHostMemoryForDeviceManager* host_memory_for_device_manager,
    std::vector<SendCallback>& send_callbacks,
    std::vector<RecvCallback>& recv_callbacks);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_HOST_CALLBACK_H_