/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"

#include <utility>

#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/tpu_api.h"

namespace tensorflow {
namespace tpu {

TpuCompilationCacheInterface::RefHolder::RefHolder(
    TpuCompilationCacheInterface* parent)
    : parent_(parent) {
  // Hold a reference to the parent until the holder is discarded.
  parent_->Ref();
}

TpuCompilationCacheInterface::RefHolder::~RefHolder() {
  parent_->DiscardEntryRefs(entries_);
  // Release our reference to the parent.
  parent_->Unref();
}

void TpuCompilationCacheInterface::RefHolder::AddRef(CompiledSubgraph* entry) {
  entries_.push_back(entry);
}

std::string TpuCompilationCacheInterface::RefHolder::DebugString() const {
  return "TpuCompilationCacheRefHolder";
}
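
// Illustrative sketch (not part of the original file): how a per-step holder
// might be driven by a caller. `cache` is assumed to be a live
// TpuCompilationCacheInterface, and the holder is assumed to be released via
// the usual RefCounted Unref(); entry refs handed to the holder during the
// step are discarded when it is destroyed.
//
//   CompilationRefHolder* holder = cache->MakePerStepRefHolder();
//   // ... compile/lookup calls that pass `holder` as per_step_ref_holder ...
//   holder->Unref();  // ~RefHolder() discards entry refs, then Unref()s cache.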

CompilationCacheEntryRef::CompilationCacheEntryRef()
    : parent_(nullptr), entry_(nullptr), index_(0) {}

CompilationCacheEntryRef::CompilationCacheEntryRef(
    TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
    : parent_(parent), entry_(entry), index_(index) {
  if (entry_ == nullptr) {
    return;
  }
  if (entry_->main_entry == nullptr) {
    entry_->Ref();
  } else {
    // This is a sharding/unsharding entry nested in a main entry. Only
    // refcount the main entry.
    entry_->main_entry->Ref();
  }
}

CompilationCacheEntryRef::~CompilationCacheEntryRef() {
  if (entry_ == nullptr) {
    return;
  }
  if (entry_->main_entry == nullptr) {
    parent_->DiscardEntryRefs({entry_});
  } else {
    parent_->DiscardEntryRefs({entry_->main_entry});
  }
}

TpuCompilationCacheEntry CompilationCacheEntryRef::get() {
  if (entry_ == nullptr) {
    // Create an empty entry if the entry is nullptr. This corresponds to
    // non-existing sharding/unsharding entries.
    return TpuCompilationCacheEntry();
  }

  return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
}

Status CompilationCacheEntryRef::ToSubEntryRef(
    CompilationCacheFetchTarget fetch_target) {
  CompiledSubgraph* target = nullptr;
  switch (fetch_target) {
    case CompilationCacheFetchTarget::MAIN:
      target = entry_;
      break;
    case CompilationCacheFetchTarget::SHARDING:
      target = entry_->sharding_entry.get();
      break;
    case CompilationCacheFetchTarget::UNSHARDING:
      target = entry_->unsharding_entry.get();
      break;
    default:
      return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
  }

  if (target == nullptr) {
    // The cache entry does not have the requested sharding/unsharding
    // subentry. Unref the main entry and replace it with nullptr.
    parent_->DiscardEntryRefs({entry_});
  }
  // Otherwise, since the refcount is always held on the main entry, no
  // ref/unref is needed here.
  entry_ = target;
  return OkStatus();
}
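
// Illustrative sketch (not part of the original file): narrowing a looked-up
// reference to its sharding subentry. `cache` and `uid` are assumed to come
// from a prior successful compilation.
//
//   std::unique_ptr<CompilationCacheEntryRef> ref;
//   TF_RETURN_IF_ERROR(cache->Lookup(uid, /*proto_index=*/0, &ref));
//   TF_RETURN_IF_ERROR(
//       ref->ToSubEntryRef(CompilationCacheFetchTarget::SHARDING));
//   // get() returns an empty entry if no sharding subentry exists.
//   TpuCompilationCacheEntry entry = ref->get();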

TpuCompilationCacheInterface::TpuCompilationCacheInterface(
    int64_t max_cache_size)
    : max_cache_size_(max_cache_size) {
  CHECK_GE(max_cache_size_, 0);
  VLOG(1) << "Created compilation cache size " << max_cache_size_ << " bytes.";
}

TpuCompilationCacheInterface::~TpuCompilationCacheInterface() {
  VLOG(1) << "TpuCompilationCacheInterface::~TpuCompilationCacheInterface()";
  // A buggy client may be holding onto a reference, or a client might have
  // crashed while holding onto a reference. In either case, discard all
  // outstanding client references to avoid leaking storage.
  for (const auto& entry : entries_by_uid_) {
    while (entry.second->external_references > 0) {
      Status s = Release(entry.first);
      CHECK(s.ok());
    }
  }
  while (!entries_by_last_use_.empty()) {
    UnloadAndDestroy(MarkOldestEntryForEviction());
  }
  // By the time the cache is deleted all reference holders should have already
  // been deleted, since they were holding references to the cache. So all
  // entries should be gone at this point.
  CHECK_EQ(cache_.size(), 0);
  CHECK_EQ(entries_by_uid_.size(), 0);
  CHECK_EQ(entries_by_proto_key_.size(), 0);
  CHECK_EQ(cache_size_, 0);
  CHECK_EQ(marked_for_eviction_size_, 0);
}

Status TpuCompilationCacheInterface::MarkEntryForEviction(
    int64_t subgraph_uid) {
  profiler::TraceMe key_release_traceme(
      "TPU compilation cache possibly evict uid",
      /*level=*/2);
  CompiledSubgraph* deleted_entry = nullptr;
  {
    absl::MutexLock lock(&mu_);
    auto iter = entries_by_uid_.find(subgraph_uid);
    if (iter == entries_by_uid_.end()) {
      // The entry has already been evicted; return OK.
      return OkStatus();
    }

    // Mark the entry for eviction.
    CompiledSubgraph* subgraph_to_evict = iter->second;
    // Entries with outstanding external references must not be evicted
    // through this API; they are released via Release() instead.
    if (subgraph_to_evict->external_references != 0) {
      return errors::Internal("Subgraph ", subgraph_to_evict->subgraph_key,
                              " external_references greater than zero. Should "
                              "use TpuCompilationCacheInterface::Release.");
    }

    VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key
            << " for eviction. Debug string: "
            << subgraph_to_evict->cache_entry_debug_string;
    entries_by_last_use_.erase(subgraph_to_evict->last_use);
    cache_size_ -= subgraph_to_evict->total_size;
    marked_for_eviction_size_ += subgraph_to_evict->total_size;

    // Evict immediately if the refcount is exactly one; otherwise only drop
    // the cache's reference to the entry, and the actual eviction happens
    // once the ref holders' references go away.
    deleted_entry = DiscardEntryRef(subgraph_to_evict);

    VLOG(1) << "After possibly evicting entry " << subgraph_uid
            << " refs cache is " << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
  }

  // Unload from the device cache if the entry was evicted from the host cache.
  UnloadAndDestroy(deleted_entry);
  return OkStatus();
}

Status TpuCompilationCacheInterface::Release(int64_t subgraph_uid) {
  profiler::TraceMe key_release_traceme("TPU compilation cache release uid",
                                        /*level=*/2);

  CompiledSubgraph* deleted_entry = nullptr;
  {
    absl::MutexLock lock(&mu_);
    auto iter = entries_by_uid_.find(subgraph_uid);

    if (iter == entries_by_uid_.end()) {
      return errors::NotFound("No cache entry found for uid ", subgraph_uid);
    }

    CHECK_GT(iter->second->external_references, 0);
    --iter->second->external_references;

    deleted_entry = DiscardEntryRef(iter->second);

    VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is "
            << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
  }
  UnloadAndDestroy(deleted_entry);
  return OkStatus();
}
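
// Illustrative sketch (not part of the original file): manual reference
// management when no per-step holder is passed to CompileIfKeyAbsent().
// Each such call hands the caller one external reference that must be
// balanced with Release(uid); `cache`, `key`, and `compile_fn` are assumed
// to exist in the caller's scope.
//
//   int64_t uid;
//   std::vector<std::string> proto_key, sharding_key;
//   std::vector<bool> may_modify_variables;
//   absl::Span<const xla::HloProto* const> hlo_metadatas;
//   TF_RETURN_IF_ERROR(cache->CompileIfKeyAbsent(
//       key, /*session_metadata=*/nullptr, /*per_step_ref_holder=*/nullptr,
//       &uid, &proto_key, &sharding_key, &may_modify_variables,
//       &hlo_metadatas, compile_fn));
//   // ... use the compiled programs ...
//   TF_RETURN_IF_ERROR(cache->Release(uid));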

void TpuCompilationCacheInterface::UnloadAndDestroy(CompiledSubgraph* entry) {
  if (!entry) return;

  CHECK(entry->RefCountIsOne());
  entry->tpu_program_group->UnloadAndDestroyPrograms();
  entry->Unref();
}

size_t TpuCompilationCacheInterface::RemoveEntry(const std::string& key) {
  auto erased = cache_.erase(key);
  TpuCompilationMetrics::SetCacheEntryCount(cache_.size());

  auto parsed_key_or_status = ParseCompilationCacheKey(key);
  CHECK(parsed_key_or_status.status().ok());
  const TpuCompilationCacheKey parsed_key =
      std::move(parsed_key_or_status).value();
  if (!parsed_key.has_guaranteed_const) {
    return erased;
  }
  session_key_map_.erase(
      strings::StrCat(parsed_key.prefix, parsed_key.session_handle));
  fingerprint_key_map_.erase(strings::StrCat(
      parsed_key.prefix, parsed_key.guaranteed_const_fingerprint()));
  return erased;
}

CompiledSubgraph* TpuCompilationCacheInterface::DiscardEntryRef(
    CompiledSubgraph* entry) {
  if (entry->RefCountIsOne()) {
    // The last reference to this entry is going away, so really delete it from
    // the cache in such a way that it can't be restored by being looked up
    // again.

    // Sanity-check that it has been marked for eviction.
    CHECK(entries_by_last_use_.find(entry->last_use) ==
          entries_by_last_use_.end());
    // Update the counter tracking how much space is taken up by entries that
    // are marked for eviction.
    marked_for_eviction_size_ -= entry->total_size;

    // Remove the entry from the cache.
    auto erased = RemoveEntry(entry->subgraph_key);

    if (erased == 0) {
      LOG(FATAL) << "Tried to discard nonexistent cache entry";
    }
    erased = entries_by_uid_.erase(entry->uid);
    CHECK_EQ(erased, 1);
    for (const std::string& key : entry->proto_key) {
      erased = entries_by_proto_key_.erase(key);
      CHECK_EQ(erased, 1);
    }
    // The actual deletion will happen outside the lock in UnloadAndDestroy().
    return entry;
  }
  entry->Unref();
  return nullptr;
}

CompilationRefHolder* TpuCompilationCacheInterface::MakePerStepRefHolder() {
  return new RefHolder(this);
}

void TpuCompilationCacheInterface::DiscardEntryRefs(
    gtl::ArraySlice<CompiledSubgraph*> entries) {
  std::vector<CompiledSubgraph*> removed_entries;
  {
    absl::MutexLock lock(&mu_);

    for (auto entry : entries) {
      removed_entries.push_back(DiscardEntryRef(entry));
    }

    VLOG(1) << "After discarding entry refs cache is " << cache_.size()
            << " entries (" << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
  }
  for (auto removed_entry : removed_entries) {
    UnloadAndDestroy(removed_entry);
  }
}

CompiledSubgraph* TpuCompilationCacheInterface::MarkOldestEntryForEviction() {
  CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
  VLOG(1) << "Marking " << entry_to_mark->subgraph_key
          << " for eviction. Debug string: "
          << entry_to_mark->cache_entry_debug_string;
  entries_by_last_use_.erase(entry_to_mark->last_use);
  cache_size_ -= entry_to_mark->total_size;
  marked_for_eviction_size_ += entry_to_mark->total_size;
  // Discard the cache's reference to the entry. If steps are holding onto
  // references to the entry, it won't be deleted until the last step holding
  // it completes. It stays in the cache in the meantime and can be resurrected
  // by a call to CompileIfKeyAbsent if that occurs before the last reference
  // expires.
  return DiscardEntryRef(entry_to_mark);
}

void TpuCompilationCacheInterface::LookupEntryMarkedForEviction(
    CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries) {
  // The entry was previously marked for eviction (or is newly created) so
  // unmark it. Add a reference (owned by the cache), update the cache size,
  // and mark something old for eviction if necessary.
  entry->Ref();
  marked_for_eviction_size_ -= entry->total_size;
  cache_size_ += entry->total_size;

  // Mark the least-recently-used non-marked entry for eviction. Never mark the
  // most-recently-used entry (i.e., do nothing if entries_by_last_use_.size()
  // == 1, which means there's only one entry not already marked for eviction),
  // so that an entry persists in the cache even if it is larger than the
  // allocated cache size.
  while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) {
    if (auto entry_to_evict = MarkOldestEntryForEviction()) {
      removed_entries->push_back(entry_to_evict);
    }
  }
}

void TpuCompilationCacheInterface::InsertEntry(const std::string& key,
                                               CompiledSubgraph* entry) {
  auto cache_inserted =
      cache_.insert(std::pair<std::string, CompiledSubgraph*>(key, entry));
  CHECK(cache_inserted.second);
  TpuCompilationMetrics::SetCacheEntryCount(cache_.size());

  auto parsed_key_or_status = ParseCompilationCacheKey(key);
  CHECK(parsed_key_or_status.status().ok());
  const TpuCompilationCacheKey parsed_key =
      std::move(parsed_key_or_status).value();
  if (!parsed_key.has_guaranteed_const) {
    return;
  }
  session_key_map_.insert(std::make_pair(
      strings::StrCat(parsed_key.prefix, parsed_key.session_handle), key));
  fingerprint_key_map_.insert(
      std::make_pair(strings::StrCat(parsed_key.prefix,
                                     parsed_key.guaranteed_const_fingerprint()),
                     key));
}

Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
    const TpuCompilationCacheKey& subgraph_key,
    const SessionMetadata* session_metadata,
    CompilationRefHolder* per_step_ref_holder, int64_t* uid,
    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
    std::vector<bool>* may_modify_variables,
    absl::Span<const xla::HloProto* const>* hlo_metadatas,
    const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
  std::vector<CompiledSubgraph*> removed_entries;
  auto status = CompileIfKeyAbsentHelper(
      subgraph_key, session_metadata, per_step_ref_holder, uid, proto_key,
      sharding_key, may_modify_variables, &removed_entries, hlo_metadatas,
      compile_function);
  for (auto entry : removed_entries) {
    UnloadAndDestroy(entry);
  }
  return status;
}
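
// Illustrative sketch (not part of the original file): the shape of the
// `compile_function` argument. The callback fills in the
// TpuProgramGroupInterface for the subgraph; BuildProgramGroup() below is a
// hypothetical helper standing in for the real XLA compilation path.
//
//   auto compile_fn = [&](TpuProgramGroupInterface* program_group) -> Status {
//     return BuildProgramGroup(metadata, program_group);  // hypothetical
//   };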

std::string TpuCompilationCacheInterface::FindCacheKey(
    const TpuCompilationCacheKey& subgraph_key) {
  if (!subgraph_key.has_guaranteed_const) {
    return subgraph_key.prefix;
  }
  auto iter = session_key_map_.find(
      strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle));
  if (iter != session_key_map_.end()) {
    return iter->second;
  }
  iter = fingerprint_key_map_.find(strings::StrCat(
      subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint()));
  if (iter != fingerprint_key_map_.end()) {
    return iter->second;
  }
  VLOG(1) << "No matching cache key found for key " << subgraph_key.ToString();
  return "";
}

Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
    const TpuCompilationCacheKey& subgraph_key,
    const SessionMetadata* session_metadata,
    CompilationRefHolder* per_step_ref_holder, int64_t* uid,
    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
    std::vector<bool>* may_modify_variables,
    std::vector<CompiledSubgraph*>* removed_entries,
    absl::Span<const xla::HloProto* const>* hlo_metadatas,
    const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
  CompiledSubgraph* entry = nullptr;

  profiler::TraceMe subgraph_lookup_traceme(
      "TPU compilation cache subgraph lookup",
      /*level=*/2);

  // NOTE: Although we use MutexLock, we do not hold the lock for the lifetime
  // of the object; see the InitializeEntry() call below.
  absl::MutexLock lock(&mu_);

  std::string cache_key = FindCacheKey(subgraph_key);
  auto iter = cache_.find(cache_key);
  bool is_new_key = iter == cache_.end();

  const std::string session_name =
      tpu::SessionNameFromMetadata(session_metadata);

  if (is_new_key) {
    cache_key = subgraph_key.ToString();
    TpuCompilationMetrics::IncrementCacheLookupCount(
        /*is_cache_hit=*/false, session_name);
    const std::string msg =
        strings::StrCat("TPU host compilation cache miss: cache_key(",
                        cache_key, "), session_name(", session_name, ")");
    TRACESTRING(msg);
    LOG(INFO) << msg;

    // Check if the caller has disabled compilation. Set using
    // internal::ScopedTpuCompileDisabler.
    if (!OpsApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) {
      const std::string error_msg = strings::StrCat(
          "[TpuCompilationDisabled]: Compilation cache miss, but compilation "
          "disabled, session_name(",
          session_name, ") Debug String: ", subgraph_key.debug_string);
      if (VLOG_IS_ON(2)) {
        VLOG(2) << "Cache Missed. Current cache entries: ";
        for (auto it = cache_.begin(); it != cache_.end(); ++it) {
          VLOG(2) << "Cache Debug Info: ";
          VLOG(2) << it->second->cache_entry_debug_string;
        }
      }

      LOG_EVERY_N_SEC(WARNING, 30) << error_msg;
      return errors::NotFound(error_msg);
    }

    // The single ref on the newly-created entry is owned by the caller.
    VLOG(1) << "Before adding new entry for key " << cache_key
            << " with session_name(" << session_name << ")"
            << "; cache is " << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_ << " bytes), "
            << "marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
    // Note that InitializeEntry() will release and reacquire mu_.
    entry = InitializeEntry(cache_key, compile_function, subgraph_key);
    bool compilation_success = entry->tpu_program_group->program_count() > 0;
    TRACELITERAL("TPU host compilation cache: compilation done.");
    LOG(INFO) << strings::StrCat(
        "TPU host compilation cache: compilation ",
        compilation_success ? "complete" : "failed", " for cache_key(",
        cache_key, "), session_name(", session_name, "), subgraph_key(",
        subgraph_key.debug_string, ")");
    // If session_name is present, log some additional stats related to HBM
    // here, so that they can be associated directly with the session.
    if (!session_name.empty()) {
      entry->tpu_program_group->LogProgramMemorySummary();
    }
  } else {
    TpuCompilationMetrics::IncrementCacheLookupCount(
        /*is_cache_hit=*/true, session_name);
    const std::string msg =
        strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
                        "), session_name(", session_name, ")");
    TRACESTRING(msg);
    VLOG(1) << msg;
    VLOG(1) << "Before refreshing entry for key " << cache_key
            << " with session_name(" << session_name << "); cache is "
            << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_ << " bytes), "
            << "marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
    entry = iter->second;
    // Make a new reference that is owned by the caller.
    entry->Ref();
    // Block if necessary until the subgraph has been initialized.
    mu_.Await(absl::Condition(
        +[](CompiledSubgraph* e) { return e->initialized; }, entry));
  }

  // Let the caller know the uid of the entry.
  *uid = entry->uid;
  // Let the caller know the keys for each of the cached protos.
  *proto_key = entry->proto_key;
  *sharding_key = entry->sharding_key;
  *may_modify_variables = entry->tpu_program_group->may_modify_variables_list();
  *hlo_metadatas = entry->tpu_program_group->hlo_metadatas();

  // If the caller didn't supply a per_step_ref_holder then the caller is going
  // to manually release the reference later via a call to Release().
  if (per_step_ref_holder == nullptr) {
    ++entry->external_references;
  } else {
    // The caller wants its reference to be handed off to a per-step holder
    // that will discard the reference when the step completes.
    RefHolder* cast_ref_holder =
        tensorflow::down_cast<RefHolder*>(per_step_ref_holder);
    CHECK_NE(cast_ref_holder, nullptr);
    cast_ref_holder->AddRef(entry);
  }

  // Remove the old LRU-table entry if it wasn't already marked for eviction.
  auto erased = entries_by_last_use_.erase(entry->last_use);
  // Update the LRU table, marking this entry as the most recently used.
  entry->last_use = use_counter_++;
  entries_by_last_use_[entry->last_use] = entry;
  if (erased == 0) {
    // The entry had been marked for eviction, or is newly created.
    LookupEntryMarkedForEviction(entry, removed_entries);
  }

  // Log a little more verbosely when a key is added.
  if (VLOG_IS_ON(1) || is_new_key) {
    LOG(INFO) << "After " << (is_new_key ? "adding" : "refreshing")
              << " entry for key " << cache_key << " with session_name "
              << session_name << " cache is " << cache_.size() << " entries ("
              << cache_size_ + marked_for_eviction_size_ << " bytes), "
              << "marked for eviction "
              << (cache_.size() - entries_by_last_use_.size()) << " entries ("
              << marked_for_eviction_size_ << " bytes).";
  }
  return entry->initialization_status;
}

Status TpuCompilationCacheInterface::GetKeysFromUid(
    int64_t uid, std::vector<std::string>* keys) {
  keys->clear();

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_uid_.find(uid);
  if (iter == entries_by_uid_.end()) {
    return errors::NotFound("No subgraph found for uid ", uid);
  }
  *keys = iter->second->proto_key;
  return OkStatus();
}

Status TpuCompilationCacheInterface::Lookup(
    int64_t uid, int proto_index,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  entry->reset();

  profiler::TraceMe proto_lookup_traceme(
      "TPU compilation cache proto lookup by uid",
      /*level=*/2);

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_uid_.find(uid);
  if (iter == entries_by_uid_.end()) {
    return errors::NotFound("No subgraph found for uid ", uid);
  }
  CompiledSubgraph* cache_entry = iter->second;
  if (proto_index < 0 ||
      proto_index >= cache_entry->tpu_program_group->program_count()) {
    return errors::NotFound("No proto found for core index ", proto_index,
                            " in subgraph with uid ", uid);
  }
  *entry = absl::make_unique<CompilationCacheEntryRef>(this, cache_entry,
                                                       proto_index);
  return OkStatus();
}

Status TpuCompilationCacheInterface::Lookup(
    const std::string& proto_key,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  entry->reset();

  profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
                                         /*level=*/2);

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_proto_key_.find(proto_key);
  if (iter == entries_by_proto_key_.end()) {
    return errors::NotFound("No proto found for key ", proto_key);
  }
  CompiledSubgraph* cache_entry = iter->second.first;
  int proto_index = iter->second.second;
  *entry = absl::make_unique<CompilationCacheEntryRef>(this, cache_entry,
                                                       proto_index);
  return OkStatus();
}
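
// Illustrative sketch (not part of the original file): looking up a program
// by one of the proto keys returned from CompileIfKeyAbsent(). `cache` and
// `proto_key` are assumed to come from a prior compilation; the returned ref
// keeps the underlying entry alive until it is destroyed.
//
//   std::unique_ptr<CompilationCacheEntryRef> ref;
//   TF_RETURN_IF_ERROR(cache->Lookup(proto_key, &ref));
//   TpuCompilationCacheEntry program = ref->get();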
}  // namespace tpu
}  // namespace tensorflow