/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"

#include <utility>

#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/tpu_api.h"

namespace tensorflow {
namespace tpu {

TpuCompilationCacheInterface::RefHolder::RefHolder(
    TpuCompilationCacheInterface* parent)
    : parent_(parent) {
  // Hold a reference to the parent until the holder is discarded.
  parent_->Ref();
}

TpuCompilationCacheInterface::RefHolder::~RefHolder() {
  parent_->DiscardEntryRefs(entries_);
  // Release our reference to the parent.
  parent_->Unref();
}

void TpuCompilationCacheInterface::RefHolder::AddRef(CompiledSubgraph* entry) {
  entries_.push_back(entry);
}

std::string TpuCompilationCacheInterface::RefHolder::DebugString() const {
  return "TpuCompilationCacheRefHolder";
}

CompilationCacheEntryRef::CompilationCacheEntryRef()
    : parent_(nullptr), entry_(nullptr), index_(0) {}

CompilationCacheEntryRef::CompilationCacheEntryRef(
    TpuCompilationCacheInterface* parent, CompiledSubgraph* entry, int index)
    : parent_(parent), entry_(entry), index_(index) {
  if (entry_ == nullptr) {
    return;
  }
  if (entry_->main_entry == nullptr) {
    entry_->Ref();
  } else {
    // This is a sharding/unsharding entry nested in a main entry. Only
    // refcount the main entry.
    entry_->main_entry->Ref();
  }
}

CompilationCacheEntryRef::~CompilationCacheEntryRef() {
  if (entry_ == nullptr) {
    return;
  }
  if (entry_->main_entry == nullptr) {
    parent_->DiscardEntryRefs({entry_});
  } else {
    parent_->DiscardEntryRefs({entry_->main_entry});
  }
}

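// Illustrative sketch of the refcounting contract above (hedged; `parent`
// and `main` are hypothetical stand-ins, not objects defined in this file):
//
//   // A ref on a sharding/unsharding entry is counted against its main
//   // entry, so both construction and destruction route through main.
//   CompilationCacheEntryRef ref(parent, main->sharding_entry.get(), 0);
//   // ~CompilationCacheEntryRef() calls parent->DiscardEntryRefs({main}),
//   // never DiscardEntryRefs on the nested sub-entry itself.
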
TpuCompilationCacheEntry CompilationCacheEntryRef::get() {
  if (entry_ == nullptr) {
    // Create an empty entry if the entry is nullptr. This corresponds to
    // non-existent sharding/unsharding entries.
    return TpuCompilationCacheEntry();
  }

  return TpuCompilationCacheEntry(entry_->tpu_program_group.get(), index_);
}

Status CompilationCacheEntryRef::ToSubEntryRef(
    CompilationCacheFetchTarget fetch_target) {
  CompiledSubgraph* target = nullptr;
  switch (fetch_target) {
    case CompilationCacheFetchTarget::MAIN:
      target = entry_;
      break;
    case CompilationCacheFetchTarget::SHARDING:
      target = entry_->sharding_entry.get();
      break;
    case CompilationCacheFetchTarget::UNSHARDING:
      target = entry_->unsharding_entry.get();
      break;
    default:
      return xla::InvalidArgument("Invalid fetch target: %d", fetch_target);
  }

  if (target == nullptr) {
    // The cache entry does not have the requested sharding/unsharding
    // subentry. Unref the main entry and replace with nullptr.
    parent_->DiscardEntryRefs({entry_});
  }
  // Otherwise, since the refcount is always on the main entry, we don't
  // need ref/unref.
  entry_ = target;
  return OkStatus();
}

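// A minimal usage sketch for ToSubEntryRef() (assumed flow; `cache` and
// `uid` are hypothetical):
//
//   std::unique_ptr<CompilationCacheEntryRef> ref;
//   TF_RETURN_IF_ERROR(cache->Lookup(uid, /*proto_index=*/0, &ref));
//   // Narrow the ref from the main program to its sharding program.
//   TF_RETURN_IF_ERROR(
//       ref->ToSubEntryRef(CompilationCacheFetchTarget::SHARDING));
//   // get() returns an empty entry when no sharding program exists.
//   TpuCompilationCacheEntry entry = ref->get();
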
TpuCompilationCacheInterface::TpuCompilationCacheInterface(
    int64_t max_cache_size)
    : max_cache_size_(max_cache_size) {
  CHECK_GE(max_cache_size_, 0);
  VLOG(1) << "Created compilation cache of size " << max_cache_size_
          << " bytes.";
}

TpuCompilationCacheInterface::~TpuCompilationCacheInterface() {
  VLOG(1) << "TpuCompilationCacheInterface::~TpuCompilationCacheInterface()";
  // A buggy client may be holding onto a reference, or a client might have
  // crashed while holding onto a reference. In either case, discard all
  // outstanding client references to avoid leaking storage.
  for (const auto& entry : entries_by_uid_) {
    while (entry.second->external_references > 0) {
      Status s = Release(entry.first);
      CHECK(s.ok());
    }
  }
  while (!entries_by_last_use_.empty()) {
    UnloadAndDestroy(MarkOldestEntryForEviction());
  }
  // By the time the cache is deleted, all reference holders should already
  // have been deleted, since they were holding references to the cache. So
  // all entries should be gone at this point.
  CHECK_EQ(cache_.size(), 0);
  CHECK_EQ(entries_by_uid_.size(), 0);
  CHECK_EQ(entries_by_proto_key_.size(), 0);
  CHECK_EQ(cache_size_, 0);
  CHECK_EQ(marked_for_eviction_size_, 0);
}

Status TpuCompilationCacheInterface::MarkEntryForEviction(
    int64_t subgraph_uid) {
  profiler::TraceMe key_release_traceme(
      "TPU compilation cache possibly evict uid",
      /*level=*/2);
  CompiledSubgraph* deleted_entry = nullptr;
  {
    absl::MutexLock lock(&mu_);
    auto iter = entries_by_uid_.find(subgraph_uid);
    if (iter == entries_by_uid_.end()) {
      // The entry has already been evicted; return OK.
      return OkStatus();
    }

    // Mark the entry for eviction.
    CompiledSubgraph* subgraph_to_evict = iter->second;
    // This API must not be used while there are external references; those
    // must be dropped via Release() first.
    if (subgraph_to_evict->external_references != 0) {
      return errors::Internal("Subgraph ", subgraph_to_evict->subgraph_key,
                              " external_references greater than zero. Should "
                              "use TpuCompilationCacheInterface::Release.");
    }

    VLOG(1) << "Marking " << subgraph_to_evict->subgraph_key
            << " for eviction. Debug string: "
            << subgraph_to_evict->cache_entry_debug_string;
    entries_by_last_use_.erase(subgraph_to_evict->last_use);
    cache_size_ -= subgraph_to_evict->total_size;
    marked_for_eviction_size_ += subgraph_to_evict->total_size;

    // Evict now if the refcount is exactly one; otherwise only discard the
    // cache's reference to the entry, and the actual eviction happens once
    // the remaining ref holders' references go away.
    deleted_entry = DiscardEntryRef(subgraph_to_evict);

    VLOG(1) << "After possibly evicting entry " << subgraph_uid
            << " refs cache is " << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
  }

  // Unload from the device cache if the entry was evicted from the host
  // cache.
  UnloadAndDestroy(deleted_entry);
  return OkStatus();
}

Status TpuCompilationCacheInterface::Release(int64_t subgraph_uid) {
  profiler::TraceMe key_release_traceme("TPU compilation cache release uid",
                                        /*level=*/2);

  CompiledSubgraph* deleted_entry = nullptr;
  {
    absl::MutexLock lock(&mu_);
    auto iter = entries_by_uid_.find(subgraph_uid);

    if (iter == entries_by_uid_.end()) {
      return errors::NotFound("No cache entry found for uid ", subgraph_uid);
    }

    CHECK_GT(iter->second->external_references, 0);
    --iter->second->external_references;

    deleted_entry = DiscardEntryRef(iter->second);

    VLOG(1) << "After releasing entry " << subgraph_uid << " refs cache is "
            << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
  }
  UnloadAndDestroy(deleted_entry);
  return OkStatus();
}

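// Release() pairs with the external reference taken by CompileIfKeyAbsent()
// when no per-step ref holder is supplied. A hedged usage sketch (error
// handling elided; `cache`, `key`, the output arguments, and `compile_fn`
// are hypothetical):
//
//   int64_t uid;
//   TF_RETURN_IF_ERROR(cache->CompileIfKeyAbsent(
//       key, /*session_metadata=*/nullptr, /*per_step_ref_holder=*/nullptr,
//       &uid, &proto_key, &sharding_key, &may_modify_variables,
//       &hlo_metadatas, compile_fn));
//   // ... use the compiled programs ...
//   TF_RETURN_IF_ERROR(cache->Release(uid));  // drop the external reference
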
void TpuCompilationCacheInterface::UnloadAndDestroy(CompiledSubgraph* entry) {
  if (!entry) return;

  CHECK(entry->RefCountIsOne());
  entry->tpu_program_group->UnloadAndDestroyPrograms();
  entry->Unref();
}

size_t TpuCompilationCacheInterface::RemoveEntry(const std::string& key) {
  auto erased = cache_.erase(key);
  TpuCompilationMetrics::SetCacheEntryCount(cache_.size());

  auto parsed_key_or_status = ParseCompilationCacheKey(key);
  CHECK(parsed_key_or_status.status().ok());
  const TpuCompilationCacheKey parsed_key =
      std::move(parsed_key_or_status).value();
  if (!parsed_key.has_guaranteed_const) {
    return erased;
  }
  session_key_map_.erase(
      strings::StrCat(parsed_key.prefix, parsed_key.session_handle));
  fingerprint_key_map_.erase(strings::StrCat(
      parsed_key.prefix, parsed_key.guaranteed_const_fingerprint()));
  return erased;
}

CompiledSubgraph* TpuCompilationCacheInterface::DiscardEntryRef(
    CompiledSubgraph* entry) {
  if (entry->RefCountIsOne()) {
    // The last reference to this entry is going away, so really delete it
    // from the cache in such a way that it can't be restored by being looked
    // up again.

    // Sanity-check that it has been marked for eviction.
    CHECK(entries_by_last_use_.find(entry->last_use) ==
          entries_by_last_use_.end());
    // Update the counter tracking how much space is taken up by entries that
    // are marked for eviction.
    marked_for_eviction_size_ -= entry->total_size;

    // Remove the entry from the cache.
    auto erased = RemoveEntry(entry->subgraph_key);

    if (erased == 0) {
      LOG(FATAL) << "Tried to discard nonexistent cache entry";
    }
    erased = entries_by_uid_.erase(entry->uid);
    CHECK_EQ(erased, 1);
    for (const std::string& key : entry->proto_key) {
      erased = entries_by_proto_key_.erase(key);
      CHECK_EQ(erased, 1);
    }
    // The actual deletion will happen outside the lock in UnloadAndDestroy().
    return entry;
  }
  entry->Unref();
  return nullptr;
}

CompilationRefHolder* TpuCompilationCacheInterface::MakePerStepRefHolder() {
  return new RefHolder(this);
}

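// Sketch of the per-step holder lifecycle (assumed usage; `cache` is a
// hypothetical TpuCompilationCacheInterface*): a step creates one holder,
// routes its lookups through it, and dropping the holder returns every
// entry ref in a single batch.
//
//   CompilationRefHolder* holder = cache->MakePerStepRefHolder();
//   // ... CompileIfKeyAbsent(..., holder, ...) calls during the step ...
//   holder->Unref();  // ~RefHolder() discards all entry refs at step end
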
void TpuCompilationCacheInterface::DiscardEntryRefs(
    gtl::ArraySlice<CompiledSubgraph*> entries) {
  std::vector<CompiledSubgraph*> removed_entries;
  {
    absl::MutexLock lock(&mu_);

    for (auto entry : entries) {
      removed_entries.push_back(DiscardEntryRef(entry));
    }

    VLOG(1) << "After discarding entry refs cache is " << cache_.size()
            << " entries (" << cache_size_ + marked_for_eviction_size_
            << " bytes), marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
  }
  for (auto removed_entry : removed_entries) {
    UnloadAndDestroy(removed_entry);
  }
}

CompiledSubgraph* TpuCompilationCacheInterface::MarkOldestEntryForEviction() {
  CompiledSubgraph* entry_to_mark = entries_by_last_use_.begin()->second;
  VLOG(1) << "Marking " << entry_to_mark->subgraph_key
          << " for eviction. Debug string: "
          << entry_to_mark->cache_entry_debug_string;
  entries_by_last_use_.erase(entry_to_mark->last_use);
  cache_size_ -= entry_to_mark->total_size;
  marked_for_eviction_size_ += entry_to_mark->total_size;
  // Discard the cache's reference to entry. If steps are holding onto
  // references to entry it won't be deleted until the last step holding it
  // completes. It stays in the cache in the meantime and can be resurrected
  // by a call to CompileIfKeyAbsent if that occurs before the last reference
  // expires.
  return DiscardEntryRef(entry_to_mark);
}

void TpuCompilationCacheInterface::LookupEntryMarkedForEviction(
    CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries) {
  // The entry was previously marked for eviction (or is newly created), so
  // unmark it. Add a reference (owned by the cache), update the cache size,
  // and mark something old for eviction if necessary.
  entry->Ref();
  marked_for_eviction_size_ -= entry->total_size;
  cache_size_ += entry->total_size;

  // Mark the least-recently-used non-marked entry for eviction. Never mark
  // the most-recently-used entry (i.e., do nothing if
  // entries_by_last_use_.size() == 1, which means there's only one entry not
  // already marked for eviction), so that an entry persists in the cache
  // even if it is larger than the allocated cache size.
  while (entries_by_last_use_.size() > 1 && cache_size_ > max_cache_size_) {
    if (auto entry_to_evict = MarkOldestEntryForEviction()) {
      removed_entries->push_back(entry_to_evict);
    }
  }
}

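// Worked example of the eviction loop above (illustrative numbers, not from
// this codebase): with max_cache_size_ = 100 bytes and unmarked entries
// A(60), B(30), C(40) giving cache_size_ = 130, the oldest entries are
// marked in LRU order until cache_size_ <= 100 or only one unmarked entry
// remains, so a single 130-byte entry would still be kept resident.
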
void TpuCompilationCacheInterface::InsertEntry(const std::string& key,
                                               CompiledSubgraph* entry) {
  auto cache_inserted =
      cache_.insert(std::pair<std::string, CompiledSubgraph*>(key, entry));
  CHECK(cache_inserted.second);
  TpuCompilationMetrics::SetCacheEntryCount(cache_.size());

  auto parsed_key_or_status = ParseCompilationCacheKey(key);
  CHECK(parsed_key_or_status.status().ok());
  const TpuCompilationCacheKey parsed_key =
      std::move(parsed_key_or_status).value();
  if (!parsed_key.has_guaranteed_const) {
    return;
  }
  session_key_map_.insert(std::make_pair(
      strings::StrCat(parsed_key.prefix, parsed_key.session_handle), key));
  fingerprint_key_map_.insert(
      std::make_pair(strings::StrCat(parsed_key.prefix,
                                     parsed_key.guaranteed_const_fingerprint()),
                     key));
}

Status TpuCompilationCacheInterface::CompileIfKeyAbsent(
    const TpuCompilationCacheKey& subgraph_key,
    const SessionMetadata* session_metadata,
    CompilationRefHolder* per_step_ref_holder, int64_t* uid,
    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
    std::vector<bool>* may_modify_variables,
    absl::Span<const xla::HloProto* const>* hlo_metadatas,
    const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
  std::vector<CompiledSubgraph*> removed_entries;
  auto status = CompileIfKeyAbsentHelper(
      subgraph_key, session_metadata, per_step_ref_holder, uid, proto_key,
      sharding_key, may_modify_variables, &removed_entries, hlo_metadatas,
      compile_function);
  for (auto entry : removed_entries) {
    UnloadAndDestroy(entry);
  }
  return status;
}

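// Hedged sketch of a compile_function argument (the signature comes from
// this file; the lambda body is illustrative only):
//
//   auto compile_fn = [&](TpuProgramGroupInterface* tpu_program_group) {
//     // Populate tpu_program_group for subgraph_key's computation here.
//     return OkStatus();
//   };
//
// The helper below runs compile_function at most once per cache key;
// concurrent callers with the same key block until the first caller's entry
// is initialized.
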
std::string TpuCompilationCacheInterface::FindCacheKey(
    const TpuCompilationCacheKey& subgraph_key) {
  if (!subgraph_key.has_guaranteed_const) {
    return subgraph_key.prefix;
  }
  auto iter = session_key_map_.find(
      strings::StrCat(subgraph_key.prefix, subgraph_key.session_handle));
  if (iter != session_key_map_.end()) {
    return iter->second;
  }
  iter = fingerprint_key_map_.find(strings::StrCat(
      subgraph_key.prefix, subgraph_key.guaranteed_const_fingerprint()));
  if (iter != fingerprint_key_map_.end()) {
    return iter->second;
  }
  VLOG(1) << "No matching cache key found for key " << subgraph_key.ToString();
  return "";
}

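// Example of the two-level key resolution above (illustrative values): for
// a key with prefix "p", session_handle "s1", and guaranteed-const
// fingerprint "f9", the lookup tries session_key_map_["ps1"] first, then
// falls back to fingerprint_key_map_["pf9"]; an empty result means no
// matching cache entry exists yet.
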
Status TpuCompilationCacheInterface::CompileIfKeyAbsentHelper(
    const TpuCompilationCacheKey& subgraph_key,
    const SessionMetadata* session_metadata,
    CompilationRefHolder* per_step_ref_holder, int64_t* uid,
    std::vector<std::string>* proto_key, std::vector<std::string>* sharding_key,
    std::vector<bool>* may_modify_variables,
    std::vector<CompiledSubgraph*>* removed_entries,
    absl::Span<const xla::HloProto* const>* hlo_metadatas,
    const std::function<Status(TpuProgramGroupInterface*)>& compile_function) {
  CompiledSubgraph* entry = nullptr;

  profiler::TraceMe subgraph_lookup_traceme(
      "TPU compilation cache subgraph lookup",
      /*level=*/2);

  // NOTE: Although we take a MutexLock here, mu_ is not held continuously
  // for the whole scope; see the InitializeEntry() call below, which
  // temporarily releases and reacquires it.
  absl::MutexLock lock(&mu_);

  std::string cache_key = FindCacheKey(subgraph_key);
  auto iter = cache_.find(cache_key);
  bool is_new_key = iter == cache_.end();

  const std::string session_name =
      tpu::SessionNameFromMetadata(session_metadata);

  if (is_new_key) {
    cache_key = subgraph_key.ToString();
    TpuCompilationMetrics::IncrementCacheLookupCount(
        /*is_cache_hit=*/false, session_name);
    const std::string msg =
        strings::StrCat("TPU host compilation cache miss: cache_key(",
                        cache_key, "), session_name(", session_name, ")");
    TRACESTRING(msg);
    LOG(INFO) << msg;

    // Check if the caller has disabled compilation; this is set using
    // internal::ScopedTpuCompileDisabler.
    if (!OpsApiFn()->TpuCompile_IsTpuCompilationEnabledFn()) {
      const std::string error_msg = strings::StrCat(
          "[TpuCompilationDisabled]: Compilation cache miss, but compilation "
          "disabled, session_name(",
          session_name, ") Debug String: ", subgraph_key.debug_string);
      if (VLOG_IS_ON(2)) {
        VLOG(2) << "Cache Missed. Current cache entries: ";
        for (auto it = cache_.begin(); it != cache_.end(); ++it) {
          VLOG(2) << "Cache Debug Info: ";
          VLOG(2) << it->second->cache_entry_debug_string;
        }
      }

      LOG_EVERY_N_SEC(WARNING, 30) << error_msg;
      return errors::NotFound(error_msg);
    }

    // The single ref on the newly-created entry is owned by the caller.
    VLOG(1) << "Before adding new entry for key " << cache_key
            << " with session_name(" << session_name << "); cache is "
            << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_ << " bytes), "
            << "marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
    // Note that InitializeEntry() will Release/Reacquire mu_.
    entry = InitializeEntry(cache_key, compile_function, subgraph_key);
    bool compilation_success = entry->tpu_program_group->program_count() > 0;
    TRACELITERAL("TPU host compilation cache: compilation done.");
    LOG(INFO) << strings::StrCat(
        "TPU host compilation cache: compilation ",
        compilation_success ? "complete" : "failed", " for cache_key(",
        cache_key, "), session_name(", session_name, "), subgraph_key(",
        subgraph_key.debug_string, ")");
    // If session_name is present, log some additional stats related to HBM
    // here, so that they can be associated directly with the session.
    if (!session_name.empty()) {
      entry->tpu_program_group->LogProgramMemorySummary();
    }
  } else {
    TpuCompilationMetrics::IncrementCacheLookupCount(
        /*is_cache_hit=*/true, session_name);
    const std::string msg =
        strings::StrCat("TPU host compilation cache hit: cache_key(", cache_key,
                        "), session_name(", session_name, ")");
    TRACESTRING(msg);
    VLOG(1) << msg;
    VLOG(1) << "Before refreshing entry for key " << cache_key
            << " with session_name(" << session_name << "); cache is "
            << cache_.size() << " entries ("
            << cache_size_ + marked_for_eviction_size_ << " bytes), "
            << "marked for eviction "
            << (cache_.size() - entries_by_last_use_.size()) << " entries ("
            << marked_for_eviction_size_ << " bytes).";
    entry = iter->second;
    // Make a new reference that is owned by the caller.
    entry->Ref();
    // Block if necessary until the subgraph has been initialized.
    mu_.Await(absl::Condition(
        +[](CompiledSubgraph* e) { return e->initialized; }, entry));
  }

  // Let the caller know the uid of the entry.
  *uid = entry->uid;
  // Let the caller know the keys for each of the cached protos.
  *proto_key = entry->proto_key;
  *sharding_key = entry->sharding_key;
  *may_modify_variables = entry->tpu_program_group->may_modify_variables_list();
  *hlo_metadatas = entry->tpu_program_group->hlo_metadatas();

  // If the caller didn't supply a per_step_ref_holder, it must release the
  // reference manually later via a call to Release().
  if (per_step_ref_holder == nullptr) {
    ++entry->external_references;
  } else {
    // The caller wants its reference to be handed off to a per-step holder
    // that will discard the reference when the step completes.
    RefHolder* cast_ref_holder =
        tensorflow::down_cast<RefHolder*>(per_step_ref_holder);
    CHECK_NE(cast_ref_holder, nullptr);
    cast_ref_holder->AddRef(entry);
  }

  // Remove the old LRU-table entry if it wasn't already marked for eviction.
  auto erased = entries_by_last_use_.erase(entry->last_use);
  // Update the LRU table, making this entry the most recently used.
  entry->last_use = use_counter_++;
  entries_by_last_use_[entry->last_use] = entry;
  if (erased == 0) {
    // The entry had been marked for eviction, or is newly created.
    LookupEntryMarkedForEviction(entry, removed_entries);
  }

  // Log a little more verbosely when a key is added.
  if (VLOG_IS_ON(1) || is_new_key) {
    LOG(INFO) << "After " << (is_new_key ? "adding" : "refreshing")
              << " entry for key " << cache_key << " with session_name "
              << session_name << "; cache is " << cache_.size() << " entries ("
              << cache_size_ + marked_for_eviction_size_ << " bytes), "
              << "marked for eviction "
              << (cache_.size() - entries_by_last_use_.size()) << " entries ("
              << marked_for_eviction_size_ << " bytes).";
  }
  return entry->initialization_status;
}

Status TpuCompilationCacheInterface::GetKeysFromUid(
    int64_t uid, std::vector<std::string>* keys) {
  keys->clear();

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_uid_.find(uid);
  if (iter == entries_by_uid_.end()) {
    return errors::NotFound("No subgraph found for uid ", uid);
  }
  *keys = iter->second->proto_key;
  return OkStatus();
}

Status TpuCompilationCacheInterface::Lookup(
    int64_t uid, int proto_index,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  entry->reset();

  profiler::TraceMe proto_lookup_traceme(
      "TPU compilation cache proto lookup by uid",
      /*level=*/2);

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_uid_.find(uid);
  if (iter == entries_by_uid_.end()) {
    return errors::NotFound("No subgraph found for uid ", uid);
  }
  CompiledSubgraph* cache_entry = iter->second;
  if (proto_index < 0 ||
      proto_index >= cache_entry->tpu_program_group->program_count()) {
    return errors::NotFound("No proto found for core index ", proto_index,
                            " in subgraph with uid ", uid);
  }
  *entry = absl::make_unique<CompilationCacheEntryRef>(this, cache_entry,
                                                       proto_index);
  return OkStatus();
}

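// Usage sketch for the uid-based Lookup() above (hypothetical caller;
// proto_index selects the per-core program within the subgraph):
//
//   std::unique_ptr<CompilationCacheEntryRef> ref;
//   TF_RETURN_IF_ERROR(cache->Lookup(uid, /*proto_index=*/core, &ref));
//   TpuCompilationCacheEntry entry = ref->get();
//   // The underlying entry stays resident until `ref` is destroyed.
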
Status TpuCompilationCacheInterface::Lookup(
    const std::string& proto_key,
    std::unique_ptr<CompilationCacheEntryRef>* entry) {
  entry->reset();

  profiler::TraceMe proto_lookup_traceme("TPU compilation cache proto lookup",
                                         /*level=*/2);

  absl::MutexLock lock(&mu_);
  const auto iter = entries_by_proto_key_.find(proto_key);
  if (iter == entries_by_proto_key_.end()) {
    return errors::NotFound("No proto found for key ", proto_key);
  }
  CompiledSubgraph* cache_entry = iter->second.first;
  int proto_index = iter->second.second;
  *entry = absl::make_unique<CompilationCacheEntryRef>(this, cache_entry,
                                                       proto_index);
  return OkStatus();
}

}  // namespace tpu
}  // namespace tensorflow