xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/host_callback.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/pjrt/host_callback.h"
17 
18 #include <utility>
19 
20 namespace xla {
21 
OnSend(int arg_num,const PjRtTransferMetadata & metadata,PjRtChunk data)22 Status HostCallbackContext::OnSend(int arg_num,
23                                    const PjRtTransferMetadata& metadata,
24                                    PjRtChunk data) {
25   const auto& arg_info = host_callback_.operands.at(arg_num);
26   const auto& host_shape = arg_info.shape;
27   const auto& device_shape = metadata.device_shape;
28 
29   size_t host_size = ShapeUtil::ByteSizeOf(host_shape);
30   DCHECK_GE(data.size(), host_size);
31 
32   auto delinearized = PjRtChunk::AllocateDefault(host_size);
33   TF_CHECK_OK(host_memory_for_device_manager_->ToHostLayout(
34       data.data(), data.size(), device_shape, delinearized.data(),
35       delinearized.size(), host_shape));
36 
37   // This assignment to update `args_` will not race with the assignments in
38   // future send ops for this `arg_num` because send callbacks are supposed to
39   // be invoked sequentially.
40   args_.at(arg_num) = std::move(delinearized);
41 
42   DCHECK_GE(ready_count_.load(), 1);
43   if (ready_count_.fetch_sub(1) != 1) return Status::OK();
44 
45   // This atomic store won't race against the next invocation of OnSend()
46   // (e.g. by the next iteration of while loop) because send callbacks are
47   // supposed to be invoked sequentially.
48   ready_count_.store(args_.size());
49 
50   std::vector<void*> arg_ptrs;
51   arg_ptrs.reserve(args_.size());
52   for (auto& arg : args_) {
53     arg_ptrs.push_back(arg.data());
54   }
55 
56   std::vector<PjRtChunk> results;
57   std::vector<void*> result_ptrs;
58   results.reserve(result_channels_.size());
59   result_ptrs.reserve(result_channels_.size());
60   for (int i = 0; i < result_channels_.size(); ++i) {
61     const auto& host_shape = host_callback_.results.at(i).shape;
62     size_t host_size = ShapeUtil::ByteSizeOf(host_shape);
63     results.push_back(PjRtChunk::AllocateDefault(host_size));
64     result_ptrs.push_back(results.back().data());
65   }
66 
67   auto status = host_callback_.callback(result_ptrs.data(), arg_ptrs.data());
68   // TODO(chky): Consider populating garbage data in results upon errors.
69 
70   // Clear the arguments for this invocation. This won't race with next
71   // invocation as send callbacks are supposed to be invoked sequentially.
72   for (auto& arg : args_) {
73     arg = PjRtChunk{};
74   }
75 
76   // Sending the results to recv callbacks if there is any. Note that after
77   // this point, this callback can be invoked again (e.g. in a loop) anytime.
78   for (int i = 0; i < result_channels_.size(); ++i) {
79     auto& result_channel = result_channels_[i];
80     result_channel->Push(std::move(results[i]));
81   }
82 
83   return status;
84 }
85 
Receive(int res_num,const PjRtTransferMetadata & metadata,CopyToDeviceStream & stream)86 void HostCallbackContext::Receive(int res_num,
87                                   const PjRtTransferMetadata& metadata,
88                                   CopyToDeviceStream& stream) {
89   auto& result_channel = result_channels_.at(res_num);
90   PjRtChunk chunk = result_channel->Pop();
91 
92   const auto& host_shape = host_callback_.results.at(res_num).shape;
93   const auto& device_shape = metadata.device_shape;
94 
95   auto statusor_linearized = host_memory_for_device_manager_->ToDeviceLayout(
96       chunk.data(), chunk.size(), host_shape, device_shape);
97   TF_CHECK_OK(stream.AddChunk(std::move(statusor_linearized).value()));
98 }
99 
100 std::unique_ptr<HostCallbackContext>
CreateHostCallbackStateAndAppendSendRecvCallbacks(HostCallback host_callback,PjRtHostMemoryForDeviceManager * host_memory_for_device_manager,std::vector<SendCallback> & send_callbacks,std::vector<RecvCallback> & recv_callbacks)101 CreateHostCallbackStateAndAppendSendRecvCallbacks(
102     HostCallback host_callback,
103     PjRtHostMemoryForDeviceManager* host_memory_for_device_manager,
104     std::vector<SendCallback>& send_callbacks,
105     std::vector<RecvCallback>& recv_callbacks) {
106   auto context = std::make_unique<HostCallbackContext>(
107       std::move(host_callback), host_memory_for_device_manager);
108 
109   const auto& hb = context->host_callback();
110   for (int arg_num = 0; arg_num < hb.operands.size(); ++arg_num) {
111     const auto& operand_info = hb.operands[arg_num];
112     send_callbacks.push_back(SendCallback{
113         /*channel_id=*/operand_info.channel_id,
114         /*callback=*/[arg_num, context = context.get()](
115                          const PjRtTransferMetadata& metadata, PjRtChunk input,
116                          size_t total_size_in_bytes, bool done) {
117           return context->OnSend(arg_num, metadata, std::move(input));
118         }});
119   }
120 
121   for (int res_num = 0; res_num < hb.results.size(); ++res_num) {
122     const auto& result_info = hb.results[res_num];
123     recv_callbacks.push_back(
124         RecvCallback{/*channel_id=*/result_info.channel_id,
125                      /*callback=*/[res_num, context = context.get()](
126                                       const PjRtTransferMetadata& metadata,
127                                       CopyToDeviceStream& stream) {
128                        context->Receive(res_num, metadata, stream);
129                      }});
130   }
131 
132   return context;
133 }
134 
135 }  // namespace xla
136