xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/compiled_subgraph.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TPU_KERNELS_COMPILED_SUBGRAPH_H_
16 #define TENSORFLOW_CORE_TPU_KERNELS_COMPILED_SUBGRAPH_H_
17 
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
24 
25 namespace tensorflow {
26 namespace tpu {
27 
28 // Forward declaration to avoid circular dependency.
29 class TpuCompilationCacheInterface;
30 
// CompiledSubgraph is one entry in the cache of compiled TPU programs, which
// works as follows.
32 //
33 // Each key identifies a unique subgraph, and the value is the vector of
34 // protos that are emitted by compiling the subgraph.
35 //
36 // When a subgraph is considered for compilation, the client calls
37 //
38 // auto subgraph_key = <compute key for subgraph>;
39 // auto compile_function = <lambda to compile subgraph into protos>;
40 // auto per_step_ref_holder = <container to control lifetime of cached
41 // results>;
42 // int64 uid;
43 // std::vector<string> proto_key;
44 // CompileIfKeyAbsent(subgraph_key, per_step_ref_holder, &uid, &proto_key,
45 //                    compile_function);
46 //
47 // where subgraph_key is the key computed for the subgraph. On success,
48 // proto_key contains a vector of keys, where proto_key[i] can be used to look
49 // up the ith proto compiled from the subgraph, and uid contains an identifier
50 // that can be used in place of key for clients that require cheap
51 // serializable handles. If the compiled protos were not present in the cache,
52 // compile_function would be called to generate them. per_step_ref_holder
53 // extends the lifetime of cached results: it is guaranteed that the protos
54 // indicated in proto_key will be available for lookup for at least as long as
55 // per_step_ref_holder is not deleted.
56 //
57 // If the caller passes nullptr instead of a per_step_ref_holder then the
58 // caller is responsible for calling Release(subgraph_key) once for every call
59 // to CompileIfKeyAbsent(subgraph_key, ...) to discard the reference to the
60 // compilation results, after the caller is sure it will not look up the
61 // compiled executables again.
62 //
63 // Subsequently the client can call
64 //
65 // std::unique_ptr<CompilationCacheEntryRef> entry;
66 // Lookup(proto_key, &entry);
67 // auto proto = entry->get();
68 //
69 // or
70 //
71 // std::unique_ptr<CompilationCacheEntryRef> entry;
72 // Lookup(uid, proto_index, &entry);
73 // auto proto = entry->get();
74 //
75 // to access a cached proto.
76 // TODO(misard) Switch the existing TPU ops to use uid+proto_index instead of
77 // string keys for proto lookups.
78 //
79 //
80 // Usage details within the system:
81 //
82 // This cache lives in the resource manager of the TPU_SYSTEM device where the
83 // compiler runs, typically worker 0 of the system. The cache is discarded and
84 // a new one created whenever the system is reinitialized.
85 //
86 // A compiled subgraph is placed into the cache using a key that is a
87 // combination of the function name, guaranteed_constants, the shapes of the
88 // dynamic inputs to the subgraph, and the function library in use at the time
89 // of execution.
90 //
91 // Whenever a compile Op is run, it looks to see if there is already an entry
92 // in the cache corresponding to that Op and the current dynamic shapes, and
93 // creates one if not. The entry is marked as most recently used in the cache
// by the compile Op. The entry is reference counted: the cache owns one
// reference to the entry, and each step that has executed a compile Op
// referring to the entry owns a reference until that step completes.
97 //
98 // If the cache exceeds a configured storage limit, entries are marked for
99 // eviction in order of least recently used. An entry is not evicted until all
100 // references to it are discarded, so an entry that is marked for eviction can
101 // still be looked up by the execute Ops in a running step. If another Compile
102 // Op looks up an entry that is marked for eviction, the entry will be
103 // unmarked and set to most recently used.
104 //
105 struct CompiledSubgraph : public core::RefCounted {
106   TpuCompilationCacheInterface* parent = nullptr;  // Not owned.
107 
108   bool initialized = false;
109 
110   // The Status returned by the compilation function when the entry is
111   // initialized. This status will be returned to any client that requests the
112   // entry.
113   Status initialization_status;
114 
115   // Counter to keep track of LRU entries for the eviction policy.
116   int64_t last_use = -1;
117 
118   // The unique key describing this entry.
119   std::string subgraph_key;
120 
121   // The uid describing this entry.
122   int64_t uid;
123 
124   // Compilation cache proto key to identify the cache entry.
125   std::vector<std::string> proto_key;
126 
127   // Fingerprints of sharding programs if there is any.
128   std::vector<std::string> sharding_key;
129 
130   // The number of 'external' client-held references to the entry.
131   int external_references = 0;
132 
133   // The sum of the SpaceUsed of each of the elements of programs; an estimate
134   // of how much RAM the entry consumes, used to determine when entries must
135   // be marked for eviction.
136   int64_t total_size = 0;
137 
138   // Debug info in case we miss.
139   std::string cache_entry_debug_string;
140 
141   // Entries representing the associated sharding and unsharding programs,
142   // which share the same life time of the owning main entry, so we always use
143   // the main entry's ref count.
144   std::unique_ptr<CompiledSubgraph> sharding_entry;
145   std::unique_ptr<CompiledSubgraph> unsharding_entry;
146 
147   // Only used for the nested sharding/unsharding entries to point to the
148   // owning main entry.
149   CompiledSubgraph* main_entry = nullptr;
150 
151   // Compiled TPU program group.
152   std::unique_ptr<TpuProgramGroupInterface> tpu_program_group;
153 
154   // Computes total program size.
ComputeTotalSizeCompiledSubgraph155   size_t ComputeTotalSize() const {
156     CHECK_EQ(total_size, 0);
157     int64_t size = tpu_program_group->program_size();
158 
159     if (sharding_entry != nullptr) {
160       size += sharding_entry->total_size;
161     }
162     if (unsharding_entry != nullptr) {
163       size += unsharding_entry->total_size;
164     }
165     return size;
166   }
167 };
168 
169 }  // namespace tpu
170 }  // namespace tensorflow
171 
172 #endif  // TENSORFLOW_CORE_TPU_KERNELS_COMPILED_SUBGRAPH_H_
173