1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/pjrt/host_callback.h"
17
18 #include <utility>
19
20 namespace xla {
21
OnSend(int arg_num,const PjRtTransferMetadata & metadata,PjRtChunk data)22 Status HostCallbackContext::OnSend(int arg_num,
23 const PjRtTransferMetadata& metadata,
24 PjRtChunk data) {
25 const auto& arg_info = host_callback_.operands.at(arg_num);
26 const auto& host_shape = arg_info.shape;
27 const auto& device_shape = metadata.device_shape;
28
29 size_t host_size = ShapeUtil::ByteSizeOf(host_shape);
30 DCHECK_GE(data.size(), host_size);
31
32 auto delinearized = PjRtChunk::AllocateDefault(host_size);
33 TF_CHECK_OK(host_memory_for_device_manager_->ToHostLayout(
34 data.data(), data.size(), device_shape, delinearized.data(),
35 delinearized.size(), host_shape));
36
37 // This assignment to update `args_` will not race with the assignments in
38 // future send ops for this `arg_num` because send callbacks are supposed to
39 // be invoked sequentially.
40 args_.at(arg_num) = std::move(delinearized);
41
42 DCHECK_GE(ready_count_.load(), 1);
43 if (ready_count_.fetch_sub(1) != 1) return Status::OK();
44
45 // This atomic store won't race against the next invocation of OnSend()
46 // (e.g. by the next iteration of while loop) because send callbacks are
47 // supposed to be invoked sequentially.
48 ready_count_.store(args_.size());
49
50 std::vector<void*> arg_ptrs;
51 arg_ptrs.reserve(args_.size());
52 for (auto& arg : args_) {
53 arg_ptrs.push_back(arg.data());
54 }
55
56 std::vector<PjRtChunk> results;
57 std::vector<void*> result_ptrs;
58 results.reserve(result_channels_.size());
59 result_ptrs.reserve(result_channels_.size());
60 for (int i = 0; i < result_channels_.size(); ++i) {
61 const auto& host_shape = host_callback_.results.at(i).shape;
62 size_t host_size = ShapeUtil::ByteSizeOf(host_shape);
63 results.push_back(PjRtChunk::AllocateDefault(host_size));
64 result_ptrs.push_back(results.back().data());
65 }
66
67 auto status = host_callback_.callback(result_ptrs.data(), arg_ptrs.data());
68 // TODO(chky): Consider populating garbage data in results upon errors.
69
70 // Clear the arguments for this invocation. This won't race with next
71 // invocation as send callbacks are supposed to be invoked sequentially.
72 for (auto& arg : args_) {
73 arg = PjRtChunk{};
74 }
75
76 // Sending the results to recv callbacks if there is any. Note that after
77 // this point, this callback can be invoked again (e.g. in a loop) anytime.
78 for (int i = 0; i < result_channels_.size(); ++i) {
79 auto& result_channel = result_channels_[i];
80 result_channel->Push(std::move(results[i]));
81 }
82
83 return status;
84 }
85
Receive(int res_num,const PjRtTransferMetadata & metadata,CopyToDeviceStream & stream)86 void HostCallbackContext::Receive(int res_num,
87 const PjRtTransferMetadata& metadata,
88 CopyToDeviceStream& stream) {
89 auto& result_channel = result_channels_.at(res_num);
90 PjRtChunk chunk = result_channel->Pop();
91
92 const auto& host_shape = host_callback_.results.at(res_num).shape;
93 const auto& device_shape = metadata.device_shape;
94
95 auto statusor_linearized = host_memory_for_device_manager_->ToDeviceLayout(
96 chunk.data(), chunk.size(), host_shape, device_shape);
97 TF_CHECK_OK(stream.AddChunk(std::move(statusor_linearized).value()));
98 }
99
100 std::unique_ptr<HostCallbackContext>
CreateHostCallbackStateAndAppendSendRecvCallbacks(HostCallback host_callback,PjRtHostMemoryForDeviceManager * host_memory_for_device_manager,std::vector<SendCallback> & send_callbacks,std::vector<RecvCallback> & recv_callbacks)101 CreateHostCallbackStateAndAppendSendRecvCallbacks(
102 HostCallback host_callback,
103 PjRtHostMemoryForDeviceManager* host_memory_for_device_manager,
104 std::vector<SendCallback>& send_callbacks,
105 std::vector<RecvCallback>& recv_callbacks) {
106 auto context = std::make_unique<HostCallbackContext>(
107 std::move(host_callback), host_memory_for_device_manager);
108
109 const auto& hb = context->host_callback();
110 for (int arg_num = 0; arg_num < hb.operands.size(); ++arg_num) {
111 const auto& operand_info = hb.operands[arg_num];
112 send_callbacks.push_back(SendCallback{
113 /*channel_id=*/operand_info.channel_id,
114 /*callback=*/[arg_num, context = context.get()](
115 const PjRtTransferMetadata& metadata, PjRtChunk input,
116 size_t total_size_in_bytes, bool done) {
117 return context->OnSend(arg_num, metadata, std::move(input));
118 }});
119 }
120
121 for (int res_num = 0; res_num < hb.results.size(); ++res_num) {
122 const auto& result_info = hb.results[res_num];
123 recv_callbacks.push_back(
124 RecvCallback{/*channel_id=*/result_info.channel_id,
125 /*callback=*/[res_num, context = context.get()](
126 const PjRtTransferMetadata& metadata,
127 CopyToDeviceStream& stream) {
128 context->Receive(res_num, metadata, stream);
129 }});
130 }
131
132 return context;
133 }
134
135 } // namespace xla
136