/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/tracked_tfrt_cpu_device_buffer.h"

#include <atomic>
#include <functional>
#include <string>
#include <utility>

#include "absl/base/casts.h"
#include "absl/synchronization/mutex.h"
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime

namespace xla {
namespace {

// Returns an AsyncValueRef<CpuEvent> that will be ready after all the async
// values in `events` are ready. If any event is in an error state, one of the
// errors is propagated through the returned async value.
tfrt::AsyncValueRef<CpuEvent> AfterAll(
    absl::Span<const tfrt::AsyncValueRef<CpuEvent>> events) {
  if (events.empty()) return tfrt::MakeAvailableAsyncValueRef<CpuEvent>();

  struct State {
    State(int count, tfrt::AsyncValueRef<CpuEvent> after_all)
        : count(count), after_all(std::move(after_all)) {}
    // Number of input events that have not completed yet.
    std::atomic<int> count;
    tfrt::AsyncValueRef<CpuEvent> after_all;

    // Guards concurrent writes to `error_message`.
    absl::Mutex mutex;
    std::string error_message;
  };

  auto after_all = tfrt::MakeConstructedAsyncValueRef<CpuEvent>();
  auto* state = new State(events.size(), after_all.CopyRef());

  for (auto& event : events) {
    event.AndThen([state, event = event.AsPtr()]() {
      if (event.IsError()) {
        absl::MutexLock lock(&state->mutex);
        state->error_message = event.GetError().message;
      }

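      // The final decrement (fetch_sub returning 1) is performed by the last
      // event to complete; acquire-release ordering makes any error_message
      // written above visible to that thread, which then publishes the result
      // and frees the shared state.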
      if (state->count.fetch_sub(1, std::memory_order_acq_rel) == 1) {
        if (!state->error_message.empty()) {
          state->after_all.SetError(state->error_message);
        } else {
          state->after_all.SetStateConcrete();
        }
        delete state;
      }
    });
  }

  return after_all;
}
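
// A minimal usage sketch (illustrative only; `copy_done` and `compute_done`
// are hypothetical events, not part of this file): AfterAll joins several
// CpuEvents into one event that becomes available once every input is ready.
//
//   tfrt::AsyncValueRef<CpuEvent> copy_done =
//       tfrt::MakeConstructedAsyncValueRef<CpuEvent>();
//   tfrt::AsyncValueRef<CpuEvent> compute_done =
//       tfrt::MakeConstructedAsyncValueRef<CpuEvent>();
//   std::array<tfrt::AsyncValueRef<CpuEvent>, 2> events = {
//       copy_done.CopyRef(), compute_done.CopyRef()};
//   tfrt::AsyncValueRef<CpuEvent> all_done = AfterAll(events);
//   all_done.AndThen([] { /* runs once both inputs are ready */ });
//   copy_done.SetStateConcrete();
//   compute_done.SetStateConcrete();  // `all_done` becomes available here.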

}  // namespace

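// Joins `definition_events` into a single event with AfterAll and delegates
// to the single-definition-event constructor below.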
TrackedTfrtCpuDeviceBuffer::TrackedTfrtCpuDeviceBuffer(
    bool is_tuple,
    absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers,
    absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events,
    std::function<void()> on_delete_callback)
    : TrackedTfrtCpuDeviceBuffer(is_tuple, std::move(buffers),
                                 AfterAll(definition_events),
                                 std::move(on_delete_callback)) {}

TrackedTfrtCpuDeviceBuffer::TrackedTfrtCpuDeviceBuffer(
    bool is_tuple,
    absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers,
    tfrt::AsyncValueRef<CpuEvent> definition_event,
    std::function<void()> on_delete_callback)
    : is_tuple_(is_tuple),
      buffers_(std::move(buffers)),
      definition_event_(std::move(definition_event)),
      on_delete_callback_(std::move(on_delete_callback)) {
  DCHECK(definition_event_);
  if (is_tuple) {
    size_t index_table_byte_size = buffers_.size() * sizeof(void*);
    // We assume tuple table allocations will not fail.
    tuple_index_table_ =
        MaybeOwningCpuMemory::AllocateShared(index_table_byte_size)
            .ValueOrDie();
    uintptr_t* index_table =
        reinterpret_cast<uintptr_t*>(tuple_index_table_->data());
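    // The index table is a flat array of pointers, one entry per leaf buffer,
    // recorded in order.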
    for (int i = 0; i < buffers_.size(); ++i) {
      index_table[i] = absl::bit_cast<uintptr_t>(buffers_[i]->data());
    }
  }
}

TrackedTfrtCpuDeviceBuffer::~TrackedTfrtCpuDeviceBuffer() {
  ReleaseDeviceMemory();
  if (on_delete_callback_) {
    on_delete_callback_();
  }
}

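// Returns the device memory for `shape_index`: the tuple index table for the
// root of a tuple shape, and the corresponding leaf buffer otherwise. Only
// flat (non-nested) tuples are supported.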
std::shared_ptr<MaybeOwningCpuMemory> TrackedTfrtCpuDeviceBuffer::Buffer(
    const ShapeIndex& shape_index) {
  if (shape_index.empty()) {
    // shape_index={}
    if (is_tuple_) return tuple_index_table_;
    return buffers_[0];
  }
  // shape_index={i}
  CHECK(is_tuple_);
  CHECK_EQ(shape_index.size(), 1) << "nested tuple not supported";
  return buffers_[shape_index[0]];
}

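// Appends events that must become available before this buffer's memory can
// safely be reused or freed.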
void TrackedTfrtCpuDeviceBuffer::AddUsageEvents(
    absl::Span<tfrt::AsyncValueRef<CpuEvent>> events) {
  // Periodically remove available usage events to prevent memory blowup.
  if (usage_events_.size() >= 1024) {
    int i = 0;
    while (i < usage_events_.size()) {
      auto& event = usage_events_.at(i);
      if (event.IsAvailable()) {
        // Swap-and-pop: overwrite the available event with the last element
        // and shrink; order is not preserved, which is fine for a set of
        // events.
        using std::swap;
        swap(event, usage_events_.back());
        usage_events_.pop_back();
        continue;
      }
      ++i;
    }
  }
  for (auto& ev : events) {
    usage_events_.push_back(std::move(ev));
  }
}

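// Moves the accumulated usage events out to the caller, e.g. so that a
// donated or deleted buffer's outstanding uses can still be awaited.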
absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4>
TrackedTfrtCpuDeviceBuffer::LockUseAndTransferUsageEvents() {
  return std::move(usage_events_);
}

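// Drops the references this buffer holds to its memory and events; owned
// memory is actually freed once the last shared_ptr reference goes away.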
void TrackedTfrtCpuDeviceBuffer::ReleaseDeviceMemory() {
  tuple_index_table_.reset();
  buffers_.clear();
  definition_event_.reset();
  usage_events_.clear();
}

}  // namespace xla