/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_HOST_CALLBACK_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_HOST_CALLBACK_H_

#include <atomic>
#include <cstdint>
#include <deque>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"

// The following provides an API for implementing host callbacks on top of
// PjRT's send/recv interface (see xla::SendCallback and xla::RecvCallback).
// While this is not the only way to implement host callbacks using send/recv,
// it is provided as an example implementation that encapsulates common
// mechanisms for host callbacks in a framework-agnostic manner.

namespace xla {

// A thread-safe queue for passing PjRtChunk objects, e.g. from Send ops to
// Recv ops.
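//
// A minimal usage sketch (illustrative only; in practice the queue is driven
// by PjRT send/recv callbacks rather than called directly):
//
//   ThreadSafePjRtChunkQueue queue;
//   // Producer thread (e.g. from a Send callback):
//   queue.Push(std::move(chunk));
//   // Consumer thread: blocks until a chunk is available.
//   PjRtChunk next = queue.Pop();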
class ThreadSafePjRtChunkQueue {
 public:
  // Push a PjRtChunk into the queue.
  void Push(PjRtChunk chunk) {
    absl::MutexLock lock(&mu_);
    queue_.push_back(std::move(chunk));
  }

  // Pop a PjRtChunk from the queue. This method blocks if the queue is empty.
  PjRtChunk Pop() {
    absl::MutexLock lock(&mu_);
    auto cond = [this]() {
      mu_.AssertHeld();
      return !queue_.empty();
    };
    mu_.Await(absl::Condition(&cond));
    auto chunk = std::move(queue_.front());
    queue_.pop_front();
    return chunk;
  }

 private:
  absl::Mutex mu_;
  std::deque<PjRtChunk> queue_ ABSL_GUARDED_BY(mu_);
};

struct HostCallbackArgInfo {
  // The channel_id associated with this value in HLO.
  uint16_t channel_id;
  // The host shape for this value.
  Shape shape;
};

struct HostCallback {
  // The metadata (e.g. channel_id, shape) for the operands and results.
  std::vector<HostCallbackArgInfo> operands;
  std::vector<HostCallbackArgInfo> results;

  // The host callback function takes two pointer arrays, where each element
  // points to a host buffer allocated according to the shape of the
  // corresponding operand or result. The first array is for the outputs and
  // the second is for the inputs. The buffers are only guaranteed to be alive
  // during the call. The callback can also return an error status to indicate
  // that the entire execution should fail.
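  //
  // For example, a hypothetical callback that adds two f32 scalars (a sketch
  // only; buffer layouts follow the operand/result shapes above):
  //
  //   [](void** outputs, void** inputs) -> Status {
  //     const float a = *reinterpret_cast<const float*>(inputs[0]);
  //     const float b = *reinterpret_cast<const float*>(inputs[1]);
  //     *reinterpret_cast<float*>(outputs[0]) = a + b;
  //     return OkStatus();
  //   }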
  std::function<Status(void**, void**)> callback;
};

// A helper class that maintains the send/recv states for a host callback.
class HostCallbackContext {
 public:
  HostCallbackContext(HostCallback host_callback, PjRtClient* client)
      : HostCallbackContext(std::move(host_callback),
                            client->GetPjRtHostMemoryForDeviceManager()) {}

  HostCallbackContext(
      HostCallback host_callback,
      PjRtHostMemoryForDeviceManager* host_memory_for_device_manager)
      : host_callback_(std::move(host_callback)),
        host_memory_for_device_manager_(host_memory_for_device_manager),
        args_(host_callback_.operands.size()),
        result_channels_(host_callback_.results.size()),
        ready_count_(args_.size()) {
    CHECK(host_memory_for_device_manager_);

    for (auto& channel : result_channels_) {
      channel = std::make_unique<ThreadSafePjRtChunkQueue>();
    }
  }

  Status OnSend(int arg_num, const PjRtTransferMetadata& metadata,
                PjRtChunk data);

  void Receive(int res_num, const PjRtTransferMetadata& metadata,
               CopyToDeviceStream& stream);

  const HostCallback& host_callback() const { return host_callback_; }

 private:
  HostCallback host_callback_;
  PjRtHostMemoryForDeviceManager* host_memory_for_device_manager_ = nullptr;
  std::vector<PjRtChunk> args_;
  std::vector<std::unique_ptr<ThreadSafePjRtChunkQueue>> result_channels_;
  std::atomic<int> ready_count_;
};

// The execution states for host callbacks across all replicas. The states are
// kept as vectors of vectors: the outer vector corresponds to the execution
// replicas, and the inner vector is the list of host callback states for a
// single execution replica.
struct HostCallbackStates {
  std::vector<std::vector<std::unique_ptr<HostCallbackContext>>> contexts;
  std::vector<std::vector<SendCallback>> send_callbacks;
  std::vector<std::vector<RecvCallback>> recv_callbacks;
};

// Creates the execution context for `host_callback` for one replica and
// appends the corresponding send/recv callbacks to `send_callbacks` and
// `recv_callbacks`.
std::unique_ptr<HostCallbackContext>
CreateHostCallbackStateAndAppendSendRecvCallbacks(
    HostCallback host_callback,
    PjRtHostMemoryForDeviceManager* host_memory_for_device_manager,
    std::vector<SendCallback>& send_callbacks,
    std::vector<RecvCallback>& recv_callbacks);
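//
// A rough usage sketch (illustrative only; `client`, `num_replicas`, and
// `host_callbacks` are assumed to be provided by the caller):
//
//   HostCallbackStates states;
//   for (int replica = 0; replica < num_replicas; ++replica) {
//     auto& contexts = states.contexts.emplace_back();
//     auto& send_callbacks = states.send_callbacks.emplace_back();
//     auto& recv_callbacks = states.recv_callbacks.emplace_back();
//     for (HostCallback& host_callback : host_callbacks) {
//       contexts.push_back(CreateHostCallbackStateAndAppendSendRecvCallbacks(
//           std::move(host_callback),
//           client->GetPjRtHostMemoryForDeviceManager(), send_callbacks,
//           recv_callbacks));
//     }
//   }
//
// The per-replica send/recv callback lists can then be passed to the PjRT
// executable's execute options so the runtime routes the corresponding
// send/recv ops through the host callback contexts.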

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_HOST_CALLBACK_H_