/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_PJRT_LRU_CACHE_H_
#define TENSORFLOW_COMPILER_XLA_PJRT_LRU_CACHE_H_
18
#include <functional>
#include <optional>

#include "absl/container/node_hash_map.h"
#include "tensorflow/core/platform/logging.h"
23
namespace xla {
25
26 // A simple LRU cache. Not thread-safe.
27 // Value must be copyable and moveable. The intent is that Value is typically
28 // a smart-pointer type.
29 template <typename Key, typename Value,
30 typename Hash = typename absl::node_hash_map<Key, Value>::hasher,
31 typename Eq = typename absl::node_hash_map<Key, Value>::key_equal>
32 class LRUCache {
33 private:
34 struct LRUListEntry {
35 LRUListEntry* next;
36 LRUListEntry* prev;
37 };
38
39 public:
40 // Multiple LRUCaches can share a LRU list, meaning that the capacity and
41 // eviction policy is shared. The user provides an LRU list
42 // to the cache constructor, and must ensure that it remains alive as long
43 // as the cache does.
44 class LRUList {
45 public:
LRUList(int capacity)46 explicit LRUList(int capacity) : capacity_(capacity) {
47 head_.next = &head_;
48 head_.prev = &head_;
49 }
~LRUList()50 ~LRUList() {
51 CHECK(head_.next == &head_);
52 CHECK(head_.prev == &head_);
53 }
54
55 LRUList(const LRUList&) = delete;
56 LRUList(LRUList&&) = delete;
57 LRUList& operator=(const LRUList&) = delete;
58 LRUList& operator=(LRUList&&) = delete;
59
Capacity()60 int Capacity() const { return capacity_; }
Size()61 int Size() const { return size_; }
62
63 void Clear();
64
65 private:
66 friend class LRUCache;
67 int capacity_;
68 int size_ = 0;
69
70 // Root of a circular doubly-linked list of entries, in order from least
71 // recently used to most recently used. An "empty" cache always contains
72 // this element in the LRU list.
73 LRUListEntry head_;
74 };
75
LRUCache(LRUList * lru_list)76 explicit LRUCache(LRUList* lru_list) : lru_list_(lru_list) {}
77 ~LRUCache();
78
79 LRUCache(const LRUCache&) = delete;
80 LRUCache(LRUCache&&) = delete;
81 LRUCache& operator=(const LRUCache&) = delete;
82 LRUCache& operator=(LRUCache&&) = delete;
83
84 // Returns the `value` associated with `key`. Creates a value with `factory`
85 // and inserts it if absent.
86 Value GetOrCreateIfAbsent(const Key& key,
87 const std::function<Value(const Key&)>& factory);
88
89 // Removes all entries from the cache.
90 void Clear();
91
Size()92 int Size() const { return entries_.size(); }
Capacity()93 int Capacity() const { return lru_list_->Capacity(); }
94
95 private:
96 LRUList* lru_list_;
97
98 struct Entry : public LRUListEntry {
99 Entry() = default;
100
101 // Pointer to the key in `entries_`. absl::node_hash_map<> promises
102 // pointer stability for keys.
103 const Key* key;
104 LRUCache* container;
105 std::optional<Value> value;
106 };
107
108 // We use `node_hash_map` because we want to guarantee pointer stability for
109 // keys and values.
110 absl::node_hash_map<Key, Entry, Hash, Eq> entries_;
111 };
112
113 template <typename Key, typename Value, typename Hash, typename Eq>
Clear()114 void LRUCache<Key, Value, Hash, Eq>::LRUList::Clear() {
115 while (head_.next != &head_) {
116 static_cast<Entry*>(head_.next)->container->Clear();
117 }
118 size_ = 0;
119 }
120
121 template <typename Key, typename Value, typename Hash, typename Eq>
Clear()122 void LRUCache<Key, Value, Hash, Eq>::Clear() {
123 for (auto& e : entries_) {
124 LRUListEntry* l = &e.second;
125 l->next->prev = l->prev;
126 l->prev->next = l->next;
127 --lru_list_->size_;
128 }
129 entries_.clear();
130 }
131
132 template <typename Key, typename Value, typename Hash, typename Eq>
~LRUCache()133 LRUCache<Key, Value, Hash, Eq>::~LRUCache() {
134 Clear();
135 }
136
137 template <typename Key, typename Value, typename Hash, typename Eq>
GetOrCreateIfAbsent(const Key & key,const std::function<Value (const Key &)> & factory)138 Value LRUCache<Key, Value, Hash, Eq>::GetOrCreateIfAbsent(
139 const Key& key, const std::function<Value(const Key&)>& factory) {
140 typename absl::node_hash_map<Key, Entry, Hash, Eq>::iterator it;
141 bool inserted;
142 std::tie(it, inserted) = entries_.try_emplace(key);
143 Entry& entry = it->second;
144 if (inserted) {
145 entry.key = &it->first;
146 entry.value = factory(*entry.key);
147 ++lru_list_->size_;
148 } else {
149 // Removes the entry from the LRU list, in preparation for adding it
150 // to the back of the list.
151 entry.prev->next = entry.next;
152 entry.next->prev = entry.prev;
153 }
154 // (Re-)adds entry to the back of the LRU list. Since it is now the
155 // most recently used element, it goes at the back.
156 LRUListEntry& lru_head = lru_list_->head_;
157 entry.container = this;
158 entry.prev = lru_head.prev;
159 entry.next = &lru_head;
160 lru_head.prev->next = &entry;
161 lru_head.prev = &entry;
162
163 Value v = *entry.value;
164
165 // Evict an LRU entry if we are over capacity.
166 if (lru_list_->size_ > lru_list_->capacity_) {
167 Entry* to_remove = static_cast<Entry*>(lru_head.next);
168 to_remove->next->prev = &lru_head;
169 lru_head.next = to_remove->next;
170 to_remove->container->entries_.erase(*to_remove->key);
171 --lru_list_->size_;
172 }
173 return v;
174 }
175
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_PJRT_LRU_CACHE_H_
179