xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/lru_cache.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_LRU_CACHE_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_LRU_CACHE_H_
18 
#include <functional>
#include <optional>

#include "absl/container/node_hash_map.h"
#include "tensorflow/core/platform/logging.h"
23 
24 namespace xla {
25 
26 // A simple LRU cache. Not thread-safe.
27 // Value must be copyable and moveable. The intent is that Value is typically
28 // a smart-pointer type.
29 template <typename Key, typename Value,
30           typename Hash = typename absl::node_hash_map<Key, Value>::hasher,
31           typename Eq = typename absl::node_hash_map<Key, Value>::key_equal>
32 class LRUCache {
33  private:
34   struct LRUListEntry {
35     LRUListEntry* next;
36     LRUListEntry* prev;
37   };
38 
39  public:
40   // Multiple LRUCaches can share a LRU list, meaning that the capacity and
41   // eviction policy is shared. The user provides an LRU list
42   // to the cache constructor, and must ensure that it remains alive as long
43   // as the cache does.
44   class LRUList {
45    public:
LRUList(int capacity)46     explicit LRUList(int capacity) : capacity_(capacity) {
47       head_.next = &head_;
48       head_.prev = &head_;
49     }
~LRUList()50     ~LRUList() {
51       CHECK(head_.next == &head_);
52       CHECK(head_.prev == &head_);
53     }
54 
55     LRUList(const LRUList&) = delete;
56     LRUList(LRUList&&) = delete;
57     LRUList& operator=(const LRUList&) = delete;
58     LRUList& operator=(LRUList&&) = delete;
59 
Capacity()60     int Capacity() const { return capacity_; }
Size()61     int Size() const { return size_; }
62 
63     void Clear();
64 
65    private:
66     friend class LRUCache;
67     int capacity_;
68     int size_ = 0;
69 
70     // Root of a circular doubly-linked list of entries, in order from least
71     // recently used to most recently used. An "empty" cache always contains
72     // this element in the LRU list.
73     LRUListEntry head_;
74   };
75 
LRUCache(LRUList * lru_list)76   explicit LRUCache(LRUList* lru_list) : lru_list_(lru_list) {}
77   ~LRUCache();
78 
79   LRUCache(const LRUCache&) = delete;
80   LRUCache(LRUCache&&) = delete;
81   LRUCache& operator=(const LRUCache&) = delete;
82   LRUCache& operator=(LRUCache&&) = delete;
83 
84   // Returns the `value` associated with `key`. Creates a value with `factory`
85   // and inserts it if absent.
86   Value GetOrCreateIfAbsent(const Key& key,
87                             const std::function<Value(const Key&)>& factory);
88 
89   // Removes all entries from the cache.
90   void Clear();
91 
Size()92   int Size() const { return entries_.size(); }
Capacity()93   int Capacity() const { return lru_list_->Capacity(); }
94 
95  private:
96   LRUList* lru_list_;
97 
98   struct Entry : public LRUListEntry {
99     Entry() = default;
100 
101     // Pointer to the key in `entries_`. absl::node_hash_map<> promises
102     // pointer stability for keys.
103     const Key* key;
104     LRUCache* container;
105     std::optional<Value> value;
106   };
107 
108   // We use `node_hash_map` because we want to guarantee pointer stability for
109   // keys and values.
110   absl::node_hash_map<Key, Entry, Hash, Eq> entries_;
111 };
112 
113 template <typename Key, typename Value, typename Hash, typename Eq>
Clear()114 void LRUCache<Key, Value, Hash, Eq>::LRUList::Clear() {
115   while (head_.next != &head_) {
116     static_cast<Entry*>(head_.next)->container->Clear();
117   }
118   size_ = 0;
119 }
120 
121 template <typename Key, typename Value, typename Hash, typename Eq>
Clear()122 void LRUCache<Key, Value, Hash, Eq>::Clear() {
123   for (auto& e : entries_) {
124     LRUListEntry* l = &e.second;
125     l->next->prev = l->prev;
126     l->prev->next = l->next;
127     --lru_list_->size_;
128   }
129   entries_.clear();
130 }
131 
132 template <typename Key, typename Value, typename Hash, typename Eq>
~LRUCache()133 LRUCache<Key, Value, Hash, Eq>::~LRUCache() {
134   Clear();
135 }
136 
137 template <typename Key, typename Value, typename Hash, typename Eq>
GetOrCreateIfAbsent(const Key & key,const std::function<Value (const Key &)> & factory)138 Value LRUCache<Key, Value, Hash, Eq>::GetOrCreateIfAbsent(
139     const Key& key, const std::function<Value(const Key&)>& factory) {
140   typename absl::node_hash_map<Key, Entry, Hash, Eq>::iterator it;
141   bool inserted;
142   std::tie(it, inserted) = entries_.try_emplace(key);
143   Entry& entry = it->second;
144   if (inserted) {
145     entry.key = &it->first;
146     entry.value = factory(*entry.key);
147     ++lru_list_->size_;
148   } else {
149     // Removes the entry from the LRU list, in preparation for adding it
150     // to the back of the list.
151     entry.prev->next = entry.next;
152     entry.next->prev = entry.prev;
153   }
154   // (Re-)adds entry to the back of the LRU list. Since it is now the
155   // most recently used element, it goes at the back.
156   LRUListEntry& lru_head = lru_list_->head_;
157   entry.container = this;
158   entry.prev = lru_head.prev;
159   entry.next = &lru_head;
160   lru_head.prev->next = &entry;
161   lru_head.prev = &entry;
162 
163   Value v = *entry.value;
164 
165   // Evict an LRU entry if we are over capacity.
166   if (lru_list_->size_ > lru_list_->capacity_) {
167     Entry* to_remove = static_cast<Entry*>(lru_head.next);
168     to_remove->next->prev = &lru_head;
169     lru_head.next = to_remove->next;
170     to_remove->container->entries_.erase(*to_remove->key);
171     --lru_list_->size_;
172   }
173   return v;
174 }
175 
176 }  // namespace xla
177 
178 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_LRU_CACHE_H_
179