/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/allocation_tracker.h"

#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {

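// Registers a single ScopedShapedBuffer under a freshly assigned
// GlobalDataHandle by treating it as a one-replica set and delegating to
// RegisterInternal.
//
// Usage sketch (names such as `tracker` and `buffer` are illustrative only):
//   TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
//                       tracker.Register(std::move(buffer), "my buffer"));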
StatusOr<GlobalDataHandle> AllocationTracker::Register(
    ScopedShapedBuffer shaped_buffer, const std::string& tag) {
  absl::MutexLock lock(&mutex_);
  VLOG(2) << "Register";
  std::vector<ScopedShapedBuffer> replicated_buffers;
  replicated_buffers.emplace_back(std::move(shaped_buffer));
  return RegisterInternal(std::move(replicated_buffers), tag);
}

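// Registers one ScopedShapedBuffer per replica under a single
// GlobalDataHandle; every replica's buffers share the same handle.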
StatusOr<GlobalDataHandle> AllocationTracker::RegisterReplicatedBuffers(
    std::vector<ScopedShapedBuffer> replicated_buffers,
    const std::string& tag) {
  absl::MutexLock lock(&mutex_);
  VLOG(2) << "RegisterReplicatedBuffers";
  return RegisterInternal(std::move(replicated_buffers), tag);
}

// ReleaseIfScopedShapedBuffer lets RegisterInternal<ShapedBufferTy>(b) call
// b.release() if b is a ScopedShapedBuffer, or otherwise pass b through
// unmodified.
static ShapedBuffer ReleaseIfScopedShapedBuffer(ShapedBuffer b) { return b; }
static ShapedBuffer ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b) {
  return b.release();
}

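// Shared implementation for Register and RegisterReplicatedBuffers: assigns
// the next handle, reference-counts every device buffer of every replica, and
// stores the buffers as non-owning ShapedBuffers under that handle in
// handle_to_shaped_buffers_.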
template <typename ShapedBufferTy>
StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
    std::vector<ShapedBufferTy> replicated_buffers, const std::string& tag) {
  static_assert(std::is_same<ShapedBufferTy, ShapedBuffer>::value ||
                    std::is_same<ShapedBufferTy, ScopedShapedBuffer>::value,
                "ShapedBufferTy must be ShapedBuffer or ScopedShapedBuffer.");
  VLOG(2) << "RegisterInternal("
          << "tag: \"" << tag << "\" with " << replicated_buffers.size()
          << " shaped_buffers.";

  int64_t handle = next_handle_++;
  for (auto& shaped_buffer : replicated_buffers) {
    std::vector<ShapeIndex> shape_indices;
    ShapeUtil::ForEachSubshape(
        shaped_buffer.on_device_shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          shape_indices.push_back(index);
        });
    // Add shaped_buffer's buffers to opaque_to_allocation_map_, which owns
    // them.
    for (const ShapeIndex& index : shape_indices) {
      AddAllocationOrIncrementRefCount(shaped_buffer.buffer(index),
                                       shaped_buffer.device_ordinal());
    }
    // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer
    // into a regular ShapedBuffer, which is stored in
    // handle_to_shaped_buffers_.
    handle_to_shaped_buffers_[handle].emplace_back(
        std::make_unique<ShapedBuffer>(
            ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
  }

  GlobalDataHandle result;
  result.set_handle(handle);
  VLOG(2) << "handle: " << handle;
  return result;
}

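// Unregisters `data`: decrements the reference count of every device buffer it
// covers (the memory is freed once a count reaches zero) and leaves nullptr
// tombstones so later lookups can distinguish "deallocated" from "never
// existed".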
Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
  absl::MutexLock lock(&mutex_);
  VLOG(2) << "Unregister("
          << "handle: " << data.handle() << ")";
  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  for (const auto& shaped_buffer : replicated_buffers) {
    std::vector<ShapeIndex> shape_indices;
    ShapeUtil::ForEachSubshape(
        shaped_buffer->on_device_shape(),
        [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) {
          shape_indices.push_back(index);
        });
    for (const ShapeIndex& index : shape_indices) {
      TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index),
                                           shaped_buffer->device_ordinal()));
    }
  }
  // Keep a nullptr as a tombstone for unregistered handles. This enables
  // better error messages. That is, "handle has been deallocated" versus
  // "handle does not exist".
  auto it = handle_to_shaped_buffers_.find(data.handle());
  if (it == handle_to_shaped_buffers_.end()) {
    return NotFound("no allocation record for global data handle: %d",
                    data.handle());
  }
  for (auto& shaped_buffer : it->second) {
    shaped_buffer.reset();
  }
  return OkStatus();
}

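// Splits a (non-nested) tuple registered under `data` into one
// GlobalDataHandle per element. Each element handle aliases the parent's
// device memory, so the shared buffers simply gain an extra reference instead
// of being copied.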
StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
    const GlobalDataHandle& data) {
  absl::MutexLock lock(&mutex_);

  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  // We only need to care about replica id 0 here, since the GlobalDataHandle
  // is the same for all buffers across replicas.
  const ShapedBuffer* shaped_buffer = replicated_buffers[0];
  if (!shaped_buffer->on_device_shape().IsTuple()) {
    return InvalidArgument("global data handle %d is not a tuple",
                           data.handle());
  }

  if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) {
    return Unimplemented("Deconstructing nested tuples is not implemented.");
  }

  std::vector<GlobalDataHandle> element_handles;
  const auto n = ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape());
  element_handles.reserve(n);
  for (int i = 0; i < n; ++i) {
    auto element_buffer = ShapedBuffer(
        ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
        shaped_buffer->device_ordinal());
    element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}),
                              /*index=*/{});
    std::vector<ShapedBuffer> replicated_buffers;
    replicated_buffers.push_back(std::move(element_buffer));
    TF_ASSIGN_OR_RETURN(
        GlobalDataHandle element_handle,
        RegisterInternal(std::move(replicated_buffers), "deconstructed tuple"));

    element_handles.push_back(element_handle);
  }
  return std::move(element_handles);
}

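// Returns the per-replica ShapedBuffers registered under `data`, taking the
// tracker lock before delegating to ResolveInternal.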
StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::Resolve(
    const GlobalDataHandle& data) const {
  absl::MutexLock lock(&mutex_);
  return AllocationTracker::ResolveInternal(data);
}

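// Returns the ShapedBuffer registered under `data` for a particular replica,
// or InvalidArgument if `replica_id` is out of range.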
StatusOr<const ShapedBuffer*> AllocationTracker::ResolveForReplica(
    const GlobalDataHandle& data, int replica_id) const {
  absl::MutexLock lock(&mutex_);
  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  if (replica_id >= replicated_buffers.size()) {
    return InvalidArgument(
        "Requesting buffer for replica %d, but found buffers only for %lu "
        "replicas.",
        replica_id, replicated_buffers.size());
  }
  return replicated_buffers[replica_id];
}

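// Lookup helper shared by Resolve, ResolveForReplica, Unregister, and
// DeconstructTuple; callers are expected to hold mutex_. Distinguishes handles
// that were never registered (NotFound) from handles whose buffers were
// already deallocated (InvalidArgument, via the nullptr tombstone).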
StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::ResolveInternal(
    const GlobalDataHandle& data) const {
  VLOG(2) << "resolve:" << data.handle();
  auto it = handle_to_shaped_buffers_.find(data.handle());
  if (it == handle_to_shaped_buffers_.end()) {
    return NotFound("no allocation record for global data handle: %d",
                    data.handle());
  }
  std::vector<const ShapedBuffer*> replicated_buffers;
  for (const auto& shaped_buffer : it->second) {
    if (shaped_buffer == nullptr) {
      return InvalidArgument("global data handle %d was previously deallocated",
                             data.handle());
    }
    replicated_buffers.push_back(shaped_buffer.get());
  }

  return replicated_buffers;
}

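// Records `device_memory` in opaque_to_allocation_map_ for `device_ordinal`:
// the first registration creates an owning entry with ref_count 1, and
// subsequent registrations of the same device pointer bump the count.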
void AllocationTracker::AddAllocationOrIncrementRefCount(
    se::DeviceMemoryBase device_memory, int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  if (it == allocation_map.end()) {
    allocation_map[device_memory.opaque()] = {
        se::OwningDeviceMemory(device_memory, device_ordinal,
                               backend_->memory_allocator()),
        /*ref_count=*/1};
  } else {
    it->second.ref_count++;
  }
}

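// Decrements the reference count for `device_memory` on `device_ordinal`,
// freeing the allocation and erasing its map entry once the count reaches
// zero. Returns an error (via TF_RET_CHECK) if the memory was never tracked.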
Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory,
                                            int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  TF_RET_CHECK(it != allocation_map.end());
  Allocation& allocation = it->second;
  TF_RET_CHECK(allocation.ref_count >= 1);
  if (allocation.ref_count == 1) {
    TF_RETURN_IF_ERROR(allocation.device_memory.Free());
    allocation_map.erase(it);
  } else {
    allocation.ref_count--;
  }
  return OkStatus();
}

}  // namespace xla