/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/container/node_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
#include "tensorflow/core/tpu/kernels/trace_util.h"

namespace tensorflow {
namespace tpu {

class TpuCompilationCacheInterface;

// Base class that holds references to compiled protos so that the protos are
// not garbage-collected before being used by execute ops. Use
// TpuCompilationCacheInterface::MakePerStepRefHolder to create an instance of
// a concrete ref holder object.
class CompilationRefHolder : public ResourceBase {
 public:
  ~CompilationRefHolder() override = default;
};

// Wrapper for a cache entry returned by all the TpuCompilationCacheInterface
// `Lookup` methods; it ensures the underlying proto is not garbage-collected
// until the client discards the ptr.
class CompilationCacheEntryRef {
 public:
  CompilationCacheEntryRef();
  CompilationCacheEntryRef(TpuCompilationCacheInterface* parent,
                           CompiledSubgraph* entry, int index);

  virtual ~CompilationCacheEntryRef();

  // Returns a TpuCompilationCacheEntry that should not be used beyond the
  // lifetime of the CompilationCacheEntryRef.
  virtual TpuCompilationCacheEntry get();

  // Mutates this ref to point to the entry's subentry (for
  // sharding/unsharding) or main entry (unchanged) as specified by
  // fetch_target. The refcount is kept unchanged, since we only track the
  // refcount of the main entry. The entry ref needs to point to the main
  // entry before this call.
  //
  // If the requested subentry does not exist, the ref will point to a nullptr
  // entry, and the original entry will be unref'ed.
  virtual Status ToSubEntryRef(CompilationCacheFetchTarget fetch_target);

 protected:
  TpuCompilationCacheInterface* parent_;  // Not owned.
  // A reference to entry_ is acquired in the constructor and released via
  // parent->DiscardEntryRefs in the destructor.
  CompiledSubgraph* entry_;
  // The index of the program in entry_ that is returned by the get method.
  int index_;
};

class TpuCompilationCacheInterface : public ResourceBase {
 public:
  explicit TpuCompilationCacheInterface(int64_t max_cache_size);
  ~TpuCompilationCacheInterface() override;

  // Ensures there is an entry for key present in the cache. By the time
  // CompileIfKeyAbsent returns there is guaranteed to be an entry in the cache
  // for key, and that entry will remain valid at least until
  // per_step_ref_holder is deleted. The first call to CompileIfKeyAbsent with a
  // key that is not in the cache will evaluate compile_function to compute the
  // value to use in the entry. Subsequent calls with the same key will block
  // until compile_function completes. Other cache reads and inserts may proceed
  // on other threads while compile_function is executing. If
  // per_step_ref_holder is nullptr then the caller is responsible for calling
  // Release(subgraph_uid) to manually discard its reference to the compiled
  // program once it no longer needs to look up the compiled program.
  //
  // compile_function should compile the subgraph represented by key and fill in
  // one TPUExecutableProto per model-parallel core into its passed argument. It
  // should return OK if and only if compilation succeeds. The executable proto
  // vector will be discarded on non-OK status.
  Status CompileIfKeyAbsent(
      const TpuCompilationCacheKey& subgraph_key,
      const SessionMetadata* session_metadata,
      CompilationRefHolder* per_step_ref_holder, int64_t* uid,
      std::vector<std::string>* proto_key,
      std::vector<std::string>* sharding_key,
      std::vector<bool>* may_modify_variables,
      absl::Span<const xla::HloProto* const>* hlo_metadatas,
      const std::function<Status(TpuProgramGroupInterface*)>& compile_function);
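
  // Illustrative sketch only (not part of the API): one way a caller might
  // drive CompileIfKeyAbsent. `cache`, `subgraph_key`, `ref_holder` and
  // `CompileSubgraphForCores` are hypothetical placeholders.
  //
  //   int64_t uid;
  //   std::vector<std::string> proto_keys;
  //   std::vector<std::string> sharding_keys;
  //   std::vector<bool> may_modify_variables;
  //   absl::Span<const xla::HloProto* const> hlo_metadatas;
  //   TF_RETURN_IF_ERROR(cache->CompileIfKeyAbsent(
  //       subgraph_key, /*session_metadata=*/nullptr, ref_holder, &uid,
  //       &proto_keys, &sharding_keys, &may_modify_variables, &hlo_metadatas,
  //       [&](TpuProgramGroupInterface* programs) {
  //         // Fill in one program per model-parallel core.
  //         return CompileSubgraphForCores(subgraph_key, programs);
  //       }));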

  // Differences between MarkEntryForEviction and Release:
  // There are two modes of managing cache entries:
  // 1) LRU eviction + pinning; 2) manual.
  // We use mode 1) if CompilationRefHolder is provided to CompileIfKeyAbsent.
  // Otherwise it is manual mode (mainly used by XRT).
  // MarkEntryForEviction should only be used in mode 1) to eagerly evict cache
  // entries when callers know that they do not need them anymore.
  // Release should only be used in mode 2) to explicitly remove an entry.

  // Mark the entry indexed by `subgraph_uid` for eviction. This should only be
  // called if per_step_ref_holder was NOT nullptr in the corresponding call to
  // CompileIfKeyAbsent(subgraph_key, ...). Otherwise, use Release(int64
  // subgraph_uid).
  Status MarkEntryForEviction(int64_t subgraph_uid);

  // Manually discards a reference to the compiled subgraph. This should only be
  // called if per_step_ref_holder was nullptr in the corresponding call to
  // CompileIfKeyAbsent(subgraph_key, ...).
  Status Release(int64_t subgraph_uid);
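
  // Hedged sketch of manual mode (2), assuming the corresponding
  // CompileIfKeyAbsent call passed per_step_ref_holder == nullptr; `cache` and
  // `uid` are placeholders from that call:
  //
  //   // ... look up and execute the compiled program via `uid` ...
  //   TF_RETURN_IF_ERROR(cache->Release(uid));  // drop the manual reference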

  // Looks up an executable corresponding to the model-parallel core index of
  // the subgraph represented by key. On success a pointer to an EntryRef
  // holding the program is returned in entry.
  Status Lookup(const std::string& proto_key,
                std::unique_ptr<CompilationCacheEntryRef>* entry);

  // Looks up an executable corresponding to the model-parallel core index of
  // the subgraph represented by uid. On success a pointer to an EntryRef
  // holding the program is returned in entry.
  Status Lookup(int64_t uid, int proto_index,
                std::unique_ptr<CompilationCacheEntryRef>* entry);
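
  // Illustrative sketch, assuming `cache`, `uid` and `core_index` come from an
  // earlier CompileIfKeyAbsent call (names are placeholders):
  //
  //   std::unique_ptr<CompilationCacheEntryRef> ref;
  //   TF_RETURN_IF_ERROR(cache->Lookup(uid, core_index, &ref));
  //   TpuCompilationCacheEntry entry = ref->get();
  //   // `entry` must not be used beyond the lifetime of `ref`.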

  // Looks up the subgraph represented by uid, and returns the vector of keys,
  // one per core, corresponding to that subgraph.
  Status GetKeysFromUid(int64_t uid, std::vector<std::string>* keys);

  // Makes a reference holder for this cache, that can be stored in the per-step
  // resource manager and will ensure that compiled entries persist until the
  // end of a step.
  CompilationRefHolder* MakePerStepRefHolder();
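
  // Hedged sketch of the intended pattern, assuming a per-step ResourceMgr
  // `step_rm` (names are placeholders, not a prescribed usage):
  //
  //   CompilationRefHolder* holder = cache->MakePerStepRefHolder();
  //   TF_RETURN_IF_ERROR(step_rm->Create(step_rm->default_container(),
  //                                      "tpu_compilation_refs", holder));
  //   // Entries referenced through `holder` stay alive until the per-step
  //   // container is cleaned up at the end of the step.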

  // Convenience method called by ~RefHolder without mu_ held. Calls
  // DiscardEntryRef on every element of entries.
  void DiscardEntryRefs(gtl::ArraySlice<CompiledSubgraph*> entries);

  std::string DebugString() const override { return "TpuCompilationCacheBase"; }

 protected:
  std::string ConstructCompilationCacheKey(const TpuCompilationCacheKey& key) {
    if (!key.has_guaranteed_const) {
      return key.prefix;
    }
    return absl::StrCat(key.prefix, "|", key.session_handle, "|",
                        key.guaranteed_const_fingerprint());
  }
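
  // For example (illustrative values only): a key with prefix "p" and no
  // guaranteed consts maps to "p", while the same prefix with session handle
  // "s" and guaranteed-const fingerprint 1234 maps to "p|s|1234".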

  // Private implementation of the generic CompilationRefHolder that knows about
  // CompiledSubgraph entries.
  class RefHolder : public CompilationRefHolder {
   public:
    explicit RefHolder(TpuCompilationCacheInterface* parent);
    ~RefHolder() override;

    // Adds entry to the list of entries that will be released when the
    // RefHolder is destroyed. Each entry is released via a call to
    // parent_->DiscardEntryRefs.
    void AddRef(CompiledSubgraph* entry);

    std::string DebugString() const override;

   private:
    TpuCompilationCacheInterface* parent_;  // Not owned.
    std::vector<CompiledSubgraph*> entries_;
  };

  // The bulk of implementation of CompileIfKeyAbsent() with the exception
  // of unloading programs that correspond to possibly removed cache
  // entries. The split helps to manage locking since we prefer to perform
  // unloading without holding extra locks.
  Status CompileIfKeyAbsentHelper(
      const TpuCompilationCacheKey& subgraph_key,
      const SessionMetadata* session_metadata,
      CompilationRefHolder* per_step_ref_holder, int64_t* uid,
      std::vector<std::string>* proto_key,
      std::vector<std::string>* sharding_key,
      std::vector<bool>* may_modify_variables,
      std::vector<CompiledSubgraph*>* removed_entries,
      absl::Span<const xla::HloProto* const>* hlo_metadatas,
      const std::function<Status(TpuProgramGroupInterface*)>& compile_function);

  // This is called by the cache when entry is marked for eviction; by
  // a RefHolder (via DiscardEntryRefs) when a step completes; and by
  // an EntryRefImpl when it is destroyed. Releases one reference to entry
  // if more than 1 remains. If only one reference is left, the entry is removed
  // from cache_ and is returned to the caller, which must eventually call
  // UnloadAndDestroy(). We do not call UnloadAndDestroy within DiscardEntryRef
  // to avoid holding the lock during program unloading.
  ABSL_MUST_USE_RESULT CompiledSubgraph* DiscardEntryRef(
      CompiledSubgraph* entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
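
  // Hedged sketch of the expected caller pattern (e.g. inside
  // DiscardEntryRefs); `entries` is a placeholder:
  //
  //   std::vector<CompiledSubgraph*> removed;
  //   {
  //     absl::MutexLock lock(&mu_);
  //     for (CompiledSubgraph* entry : entries) {
  //       if (CompiledSubgraph* removed_entry = DiscardEntryRef(entry)) {
  //         removed.push_back(removed_entry);
  //       }
  //     }
  //   }
  //   // Unload outside the lock, since unloading may perform device IO.
  //   for (CompiledSubgraph* removed_entry : removed) {
  //     UnloadAndDestroy(removed_entry);
  //   }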

  // Marks the oldest unmarked entry for eviction. Requires that there is at
  // least one such entry. In case the evicted entry had only 1 reference it
  // is removed from the cache and returned to the caller which must eventually
  // call UnloadAndDestroy.
  ABSL_MUST_USE_RESULT CompiledSubgraph* MarkOldestEntryForEviction()
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Updates data structures to indicate that entry, which had been marked for
  // eviction, has been looked up. This is called by CompileIfKeyAbsent when an
  // entry is newly created, or an entry that has been marked for eviction but
  // not yet evicted is looked up.
  //
  // First the entry is unmarked for eviction, i.e. the cache gains a reference
  // to entry, entry's last_use field is set to be the most recent value of
  // use_counter_ and entries_by_last_use_ is updated accordingly.
  //
  // Next, the size of the cache is examined to see if any other entries need to
  // be marked for eviction now that entry has been unmarked. While the total
  // size of unmarked cached entries is greater than max_cache_size_, entries
  // are marked for eviction in LRU order. The most recently used entry is never
  // marked for eviction, so an entry larger than the max cache size will remain
  // in the cache until it is replaced by something else. In case some entries
  // actually were removed from the cache, they are returned to the caller via
  // removed_entries. The caller must eventually delete them by calling
  // UnloadAndDestroy.
  void LookupEntryMarkedForEviction(
      CompiledSubgraph* entry, std::vector<CompiledSubgraph*>* removed_entries)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
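
  // Rough sketch of the LRU re-balancing step described above (illustrative
  // only, not the actual implementation):
  //
  //   while (cache_size_ > max_cache_size_ && entries_by_last_use_.size() > 1) {
  //     if (CompiledSubgraph* removed = MarkOldestEntryForEviction()) {
  //       removed_entries->push_back(removed);
  //     }
  //   }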

  // Removes the entry with given key from cache.
  size_t RemoveEntry(const std::string& key) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Inserts the given key and entry to cache.
  void InsertEntry(const std::string& key, CompiledSubgraph* entry)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Returns the cache key matching given subgraph_key.
  std::string FindCacheKey(const TpuCompilationCacheKey& subgraph_key)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);

  // Creates a new entry by running initialize_programs and places it in the
  // cache to be looked up by key. The new entry is in the 'marked for eviction'
  // state (not present in entries_by_last_use_) and the caller is expected to
  // call LookupEntryMarkedForEviction after InitializeEntry.
  //
  // **InitializeEntry releases mu_ during the call to initialize_programs.**
  virtual CompiledSubgraph* InitializeEntry(
      const std::string& key,
      const std::function<Status(TpuProgramGroupInterface*)>&
          initialize_programs,
      const TpuCompilationCacheKey& subgraph_key)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) = 0;

  // Unloads the program associated with the entry from all local devices
  // and deletes the entry itself. It is assumed no one else has a reference
  // to it and all related keys have already been removed from the cache.
  // The call can perform device IO so no locks should be held while calling it.
  void UnloadAndDestroy(CompiledSubgraph* entry) ABSL_LOCKS_EXCLUDED(mu_);

  // The maximum size of entries that are stored in the cache before entries are
  // marked for eviction.
  const int64_t max_cache_size_;
  // Mutex to protect access to shared resources in a multi-threaded
  // environment.
  absl::Mutex mu_;
  // The total size of entries that are stored and not marked for eviction.
  int64_t cache_size_ ABSL_GUARDED_BY(mu_) = 0;
  // The total size of entries that are marked for eviction.
  int64_t marked_for_eviction_size_ ABSL_GUARDED_BY(mu_) = 0;
  // The value to assign to the last_use field of the next entry that is looked
  // up.
  int64_t use_counter_ ABSL_GUARDED_BY(mu_) = 0;
  // session_key_map_ and fingerprint_key_map_ are used for looking up the
  // cache_ key matching a given subgraph key. When doing a lookup, check
  // session_key_map_ first to avoid unnecessary fingerprint computation.
  // Map from key prefix + session_handle to a cache_ key.
  absl::node_hash_map<std::string, std::string> session_key_map_
      ABSL_GUARDED_BY(mu_);
  // Map from key prefix + fingerprint to a cache_ key.
  absl::node_hash_map<std::string, std::string> fingerprint_key_map_
      ABSL_GUARDED_BY(mu_);
  // All the subgraph entries that can be looked up in the cache. An entry is
  // marked for eviction iff it is present in cache_ and not in
  // entries_by_last_use_.
  std::unordered_map<std::string, CompiledSubgraph*> cache_
      ABSL_GUARDED_BY(mu_);
  // All the subgraph entries that can be looked up in the cache, indexed by
  // uid.
  absl::node_hash_map<int64_t, CompiledSubgraph*> entries_by_uid_
      ABSL_GUARDED_BY(mu_);
  // All the protos that can be looked up in the cache, indexed by proto
  // key. The value of the map is a subgraph and the index of the proto compiled
  // for that subgraph.
  std::unordered_map<std::string, std::pair<CompiledSubgraph*, int>>
      entries_by_proto_key_ ABSL_GUARDED_BY(mu_);
  // Map from last_use to entry, used to mark entries for eviction in LRU
  // order. If an entry's last_use counter is not present as a key in
  // entries_by_last_use_ then the entry has been marked for eviction.
  std::map<int64_t, CompiledSubgraph*> entries_by_last_use_
      ABSL_GUARDED_BY(mu_);

  TpuCompilationMetrics tpu_compilation_metrics_;

 private:
  TpuCompilationCacheInterface(const TpuCompilationCacheInterface&) = delete;
  TpuCompilationCacheInterface& operator=(const TpuCompilationCacheInterface&) =
      delete;
};
}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_INTERFACE_H_