/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h"

#include <atomic>
#include <functional>
#include <string>
#include <utility>

#include "absl/base/casts.h"
#include "absl/synchronization/mutex.h"
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime

namespace xla {
namespace {

// Returns an AsyncValueRef<CpuEvent> that will be ready after all the async
// values in `events` are ready. If errors occur, one of the errors will be
// propagated through the returned async value.
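//
// A minimal usage sketch (the names below are illustrative, not from this
// file):
//
//   tfrt::AsyncValueRef<CpuEvent> all_done = AfterAll(some_events);
//   all_done.AndThen(
//       [] { /* Runs once every event in `some_events` is ready. */ });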
tfrt::AsyncValueRef<CpuEvent> AfterAll(
    absl::Span<const tfrt::AsyncValueRef<CpuEvent>> events) {
  if (events.empty()) return tfrt::MakeAvailableAsyncValueRef<CpuEvent>();

  struct State {
    State(int count, tfrt::AsyncValueRef<CpuEvent> after_all)
        : count(count), after_all(std::move(after_all)) {}
    std::atomic<int> count;
    tfrt::AsyncValueRef<CpuEvent> after_all;

    absl::Mutex mutex;
    std::string error_message;
  };

  auto after_all = tfrt::MakeConstructedAsyncValueRef<CpuEvent>();
  auto* state = new State(events.size(), after_all.CopyRef());

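  // Each event decrements `count` when it resolves; the callback that brings
  // the count to zero publishes the result (error or concrete) and deletes
  // `state`.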
  for (auto& event : events) {
    event.AndThen([state, event = event.AsPtr()]() {
      if (event.IsError()) {
        absl::MutexLock lock(&state->mutex);
        state->error_message = event.GetError().message;
      }

      if (state->count.fetch_sub(1, std::memory_order_acq_rel) == 1) {
        if (!state->error_message.empty()) {
          state->after_all.SetError(state->error_message);
        } else {
          state->after_all.SetStateConcrete();
        }
        delete state;
      }
    });
  }

  return after_all;
}

}  // namespace

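// Collapses `definition_events` into a single event with AfterAll and
// delegates to the single-definition-event constructor.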
TrackedTfrtCpuDeviceBuffer::TrackedTfrtCpuDeviceBuffer(
    bool is_tuple,
    absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers,
    absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events,
    std::function<void()> on_delete_callback)
    : TrackedTfrtCpuDeviceBuffer(is_tuple, std::move(buffers),
                                 AfterAll(definition_events),
                                 std::move(on_delete_callback)) {}

TrackedTfrtCpuDeviceBuffer::TrackedTfrtCpuDeviceBuffer(
    bool is_tuple,
    absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers,
    tfrt::AsyncValueRef<CpuEvent> definition_event,
    std::function<void()> on_delete_callback)
    : is_tuple_(is_tuple),
      buffers_(std::move(buffers)),
      definition_event_(std::move(definition_event)),
      on_delete_callback_(std::move(on_delete_callback)) {
  DCHECK(definition_event_);
  if (is_tuple) {
    size_t index_table_byte_size = buffers_.size() * sizeof(void*);
    // We assume tuple table allocations will not fail.
    tuple_index_table_ =
        MaybeOwningCpuMemory::AllocateShared(index_table_byte_size)
            .ValueOrDie();
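    // Record the address of each leaf buffer in the tuple index table.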
    uintptr_t* index_table =
        reinterpret_cast<uintptr_t*>(tuple_index_table_->data());
    for (int i = 0; i < buffers_.size(); ++i) {
      index_table[i] = absl::bit_cast<uintptr_t>(buffers_[i]->data());
    }
  }
}

TrackedTfrtCpuDeviceBuffer::~TrackedTfrtCpuDeviceBuffer() {
  ReleaseDeviceMemory();
  if (on_delete_callback_) {
    on_delete_callback_();
  }
}

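// Returns the buffer at `shape_index`: the tuple index table for the root of
// a tuple, or the corresponding leaf buffer otherwise.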
std::shared_ptr<MaybeOwningCpuMemory> TrackedTfrtCpuDeviceBuffer::Buffer(
    const ShapeIndex& shape_index) {
  if (shape_index.empty()) {
    // shape_index={}
    if (is_tuple_) return tuple_index_table_;
    return buffers_[0];
  }
  // shape_index={i}
  CHECK(is_tuple_);
  CHECK_EQ(shape_index.size(), 1) << "nested tuple not supported";
  return buffers_[shape_index[0]];
}

void TrackedTfrtCpuDeviceBuffer::AddUsageEvents(
    absl::Span<tfrt::AsyncValueRef<CpuEvent>> events) {
  // Periodically remove available usage events to prevent memory blowup.
  if (usage_events_.size() >= 1024) {
    int i = 0;
    while (i < usage_events_.size()) {
      auto& event = usage_events_.at(i);
      if (event.IsAvailable()) {
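        // Drop the already-available event by swapping it with the last
        // element; the swapped-in element is re-examined before `i` advances.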
        using std::swap;
        swap(event, usage_events_.back());
        usage_events_.pop_back();
        continue;
      }
      ++i;
    }
  }
  for (auto& ev : events) {
    usage_events_.push_back(std::move(ev));
  }
}

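// Transfers the accumulated usage events to the caller; this buffer no longer
// tracks them afterwards.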
absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4>
TrackedTfrtCpuDeviceBuffer::LockUseAndTransferUsageEvents() {
  return std::move(usage_events_);
}

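// Drops this buffer's references to its device memory and tracked events so
// the underlying allocations can be freed once no other holders remain.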
void TrackedTfrtCpuDeviceBuffer::ReleaseDeviceMemory() {
  tuple_index_table_.reset();
  buffers_.clear();
  definition_event_.reset();
  usage_events_.clear();
}

}  // namespace xla