/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/allocation_tracker.h"

#include <memory>
#include <utility>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {

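// Registers a single ScopedShapedBuffer under a fresh GlobalDataHandle that
// clients can use to refer to it.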
StatusOr<GlobalDataHandle> AllocationTracker::Register(
    ScopedShapedBuffer shaped_buffer, const std::string& tag) {
  absl::MutexLock lock(&mutex_);
  VLOG(2) << "Register";
  std::vector<ScopedShapedBuffer> replicated_buffers;
  replicated_buffers.emplace_back(std::move(shaped_buffer));
  return RegisterInternal(std::move(replicated_buffers), tag);
}

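// Registers one ScopedShapedBuffer per replica under a single handle; the
// same GlobalDataHandle refers to all replicas' buffers.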
StatusOr<GlobalDataHandle> AllocationTracker::RegisterReplicatedBuffers(
    std::vector<ScopedShapedBuffer> replicated_buffers,
    const std::string& tag) {
  absl::MutexLock lock(&mutex_);
  VLOG(2) << "RegisterReplicatedBuffers";
  return RegisterInternal(std::move(replicated_buffers), tag);
}
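// A minimal usage sketch of the handle lifecycle (hypothetical caller code,
// not part of this file; assumes a `tracker` constructed with the service
// backend and a ScopedShapedBuffer `result` produced by an execution):
//
//   TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
//                       tracker.Register(std::move(result), "my result"));
//   TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> buffers,
//                       tracker.Resolve(handle));
//   TF_RETURN_IF_ERROR(tracker.Unregister(handle));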

// ReleaseIfScopedShapedBuffer lets RegisterInternal<ShapedBufferTy>(b) call
// b.release() if b is a ScopedShapedBuffer, or otherwise pass b through
// unmodified.
static ShapedBuffer ReleaseIfScopedShapedBuffer(ShapedBuffer b) { return b; }
static ShapedBuffer ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b) {
  return b.release();
}

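// Shared implementation of Register and RegisterReplicatedBuffers: assigns a
// fresh handle, adds every device buffer of each replica's ShapedBuffer to the
// per-device allocation map (taking ownership on first sight, incrementing the
// ref count otherwise), and stores the now-unscoped ShapedBuffers under the
// handle.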
template <typename ShapedBufferTy>
StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
    std::vector<ShapedBufferTy> replicated_buffers, const std::string& tag) {
  static_assert(std::is_same<ShapedBufferTy, ShapedBuffer>::value ||
                    std::is_same<ShapedBufferTy, ScopedShapedBuffer>::value,
                "ShapedBufferTy must be ShapedBuffer or ScopedShapedBuffer.");
  VLOG(2) << "RegisterInternal("
          << "tag: \"" << tag << "\" with " << replicated_buffers.size()
          << " shaped_buffers.";

  int64_t handle = next_handle_++;
  for (auto& shaped_buffer : replicated_buffers) {
    std::vector<ShapeIndex> shape_indices;
    ShapeUtil::ForEachSubshape(
        shaped_buffer.on_device_shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          shape_indices.push_back(index);
        });
    // Add shaped_buffer's buffers to opaque_to_allocation_map_, which owns
    // them.
    for (const ShapeIndex& index : shape_indices) {
      AddAllocationOrIncrementRefCount(shaped_buffer.buffer(index),
                                       shaped_buffer.device_ordinal());
    }
    // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer
    // into a regular ShapedBuffer, which is stored in
    // handle_to_shaped_buffers_.
    handle_to_shaped_buffers_[handle].emplace_back(
        std::make_unique<ShapedBuffer>(
            ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
  }

  GlobalDataHandle result;
  result.set_handle(handle);
  VLOG(2) << "handle: " << handle;
  return result;
}

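// Decrements the ref count of every device buffer backing `data` and replaces
// the handle's stored ShapedBuffers with nullptr tombstones, so later lookups
// report "deallocated" rather than "not found".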
Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
  absl::MutexLock lock(&mutex_);
  VLOG(2) << "Unregister("
          << "handle: " << data.handle() << ")";
  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  for (const auto& shaped_buffer : replicated_buffers) {
    std::vector<ShapeIndex> shape_indices;
    ShapeUtil::ForEachSubshape(
        shaped_buffer->on_device_shape(),
        [&shape_indices](const Shape& /*subshape*/, const ShapeIndex& index) {
          shape_indices.push_back(index);
        });
    for (const ShapeIndex& index : shape_indices) {
      TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index),
                                           shaped_buffer->device_ordinal()));
    }
  }
  // Keep a nullptr as a tombstone for unregistered handles. This enables
  // better error messages. That is, "handle has been deallocated" versus
  // "handle does not exist".
  auto it = handle_to_shaped_buffers_.find(data.handle());
  if (it == handle_to_shaped_buffers_.end()) {
    return NotFound("no allocation record for global data handle: %d",
                    data.handle());
  }
  for (auto& shaped_buffer : it->second) {
    shaped_buffer.reset();
  }
  return OkStatus();
}

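// Splits a (non-nested) tuple handle into one GlobalDataHandle per element.
// Element handles alias the tuple's device buffers and bump their ref counts
// rather than copying the data.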
StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
    const GlobalDataHandle& data) {
  absl::MutexLock lock(&mutex_);

  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  // We only need to care about replica id 0 here, since the GlobalDataHandle
  // is the same for all buffers across replicas.
  const ShapedBuffer* shaped_buffer = replicated_buffers[0];
  if (!shaped_buffer->on_device_shape().IsTuple()) {
    return InvalidArgument("global data handle %d is not a tuple",
                           data.handle());
  }

  if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) {
    return Unimplemented("Deconstructing nested tuples is not implemented.");
  }

  std::vector<GlobalDataHandle> element_handles;
  const auto n = ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape());
  element_handles.reserve(n);
  for (int i = 0; i < n; ++i) {
    auto element_buffer = ShapedBuffer(
        ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i),
        shaped_buffer->device_ordinal());
    element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}),
                              /*index=*/{});
    std::vector<ShapedBuffer> replicated_buffers;
    replicated_buffers.push_back(std::move(element_buffer));
    TF_ASSIGN_OR_RETURN(
        GlobalDataHandle element_handle,
        RegisterInternal(std::move(replicated_buffers), "deconstructed tuple"));

    element_handles.push_back(element_handle);
  }
  return std::move(element_handles);
}

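// Resolves a handle to the per-replica ShapedBuffers registered under it.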
StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::Resolve(
    const GlobalDataHandle& data) const {
  absl::MutexLock lock(&mutex_);
  return AllocationTracker::ResolveInternal(data);
}

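// Resolves a handle to the ShapedBuffer registered for a single replica.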
StatusOr<const ShapedBuffer*> AllocationTracker::ResolveForReplica(
    const GlobalDataHandle& data, int replica_id) const {
  absl::MutexLock lock(&mutex_);
  TF_ASSIGN_OR_RETURN(std::vector<const ShapedBuffer*> replicated_buffers,
                      ResolveInternal(data));
  if (replica_id >= replicated_buffers.size()) {
    return InvalidArgument(
        "Requesting buffer for replica %d, but found buffers only for %lu "
        "replicas.",
        replica_id, replicated_buffers.size());
  }
  return replicated_buffers[replica_id];
}

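// Lookup helper shared by the resolve methods and Unregister; callers must
// hold mutex_. Distinguishes handles that were never registered from handles
// whose buffers have already been deallocated.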
StatusOr<std::vector<const ShapedBuffer*>> AllocationTracker::ResolveInternal(
    const GlobalDataHandle& data) const {
  VLOG(2) << "resolve:" << data.handle();
  auto it = handle_to_shaped_buffers_.find(data.handle());
  if (it == handle_to_shaped_buffers_.end()) {
    return NotFound("no allocation record for global data handle: %d",
                    data.handle());
  }
  std::vector<const ShapedBuffer*> replicated_buffers;
  for (const auto& shaped_buffer : it->second) {
    if (shaped_buffer == nullptr) {
      return InvalidArgument("global data handle %d was previously deallocated",
                             data.handle());
    }
    replicated_buffers.push_back(shaped_buffer.get());
  }

  return replicated_buffers;
}

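// Takes ownership of `device_memory` on the given device ordinal the first
// time an address is seen (ref count 1); on later registrations of the same
// address it only increments the ref count.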
void AllocationTracker::AddAllocationOrIncrementRefCount(
    se::DeviceMemoryBase device_memory, int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  if (it == allocation_map.end()) {
    allocation_map[device_memory.opaque()] = {
        se::OwningDeviceMemory(device_memory, device_ordinal,
                               backend_->memory_allocator()),
        /*ref_count=*/1};
  } else {
    it->second.ref_count++;
  }
}

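// Decrements the ref count of `device_memory`; once it reaches zero, frees the
// memory and drops the tracking entry.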
Status AllocationTracker::DecrementRefCount(se::DeviceMemoryBase device_memory,
                                            int device_ordinal) {
  AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal];
  auto it = allocation_map.find(device_memory.opaque());
  TF_RET_CHECK(it != allocation_map.end());
  Allocation& allocation = it->second;
  TF_RET_CHECK(allocation.ref_count >= 1);
  if (allocation.ref_count == 1) {
    TF_RETURN_IF_ERROR(allocation.device_memory.Free());
    allocation_map.erase(it);
  } else {
    allocation.ref_count--;
  }
  return OkStatus();
}

}  // namespace xla