/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_COMPILED_SUBGRAPH_H_
#define TENSORFLOW_CORE_TPU_KERNELS_COMPILED_SUBGRAPH_H_

#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"

namespace tensorflow {
namespace tpu {

// Forward declaration to avoid circular dependency.
class TpuCompilationCacheInterface;

// Cache for compiled TPU program.
//
// Each key identifies a unique subgraph, and the value is the vector of
// protos that are emitted by compiling the subgraph.
//
// When a subgraph is considered for compilation, the client calls
//
//   auto subgraph_key = <compute key for subgraph>;
//   auto compile_function = <lambda to compile subgraph into protos>;
//   auto per_step_ref_holder = <container to control lifetime of cached
//   results>;
//   int64 uid;
//   std::vector<string> proto_key;
//   CompileIfKeyAbsent(subgraph_key, per_step_ref_holder, &uid, &proto_key,
//                      compile_function);
//
// where subgraph_key is the key computed for the subgraph.
// On success, proto_key contains a vector of keys, where proto_key[i] can be
// used to look up the i-th proto compiled from the subgraph, and uid contains
// an identifier that can be used in place of the string keys for clients that
// require cheap serializable handles. If the compiled protos were not present
// in the cache, compile_function is called to generate them.
// per_step_ref_holder extends the lifetime of cached results: it is
// guaranteed that the protos indicated in proto_key will be available for
// lookup for at least as long as per_step_ref_holder is not deleted.
//
// If the caller passes nullptr instead of a per_step_ref_holder then the
// caller is responsible for calling Release(subgraph_key) once for every call
// to CompileIfKeyAbsent(subgraph_key, ...) to discard the reference to the
// compilation results, after the caller is sure it will not look up the
// compiled executables again.
//
// Subsequently the client can call
//
//   std::unique_ptr<CompilationCacheEntryRef> entry;
//   Lookup(proto_key, &entry);
//   auto proto = entry->get();
//
// or
//
//   std::unique_ptr<CompilationCacheEntryRef> entry;
//   Lookup(uid, proto_index, &entry);
//   auto proto = entry->get();
//
// to access a cached proto.
// TODO(misard): Switch the existing TPU ops to use uid+proto_index instead of
// string keys for proto lookups.
//
//
// Usage details within the system:
//
// This cache lives in the resource manager of the TPU_SYSTEM device where the
// compiler runs, typically worker 0 of the system. The cache is discarded and
// a new one created whenever the system is reinitialized.
//
// A compiled subgraph is placed into the cache using a key that is a
// combination of the function name, guaranteed_constants, the shapes of the
// dynamic inputs to the subgraph, and the function library in use at the time
// of execution.
90 // 91 // Whenever a compile Op is run, it looks to see if there is already an entry 92 // in the cache corresponding to that Op and the current dynamic shapes, and 93 // creates one if not. The entry is marked as most recently used in the cache 94 // by the compile Op. The entry is reference counted. The cache owns one entry 95 // , and each step that has executed a compile Op referring to the entry owns 96 // a reference until that step completes. 97 // 98 // If the cache exceeds a configured storage limit, entries are marked for 99 // eviction in order of least recently used. An entry is not evicted until all 100 // references to it are discarded, so an entry that is marked for eviction can 101 // still be looked up by the execute Ops in a running step. If another Compile 102 // Op looks up an entry that is marked for eviction, the entry will be 103 // unmarked and set to most recently used. 104 // 105 struct CompiledSubgraph : public core::RefCounted { 106 TpuCompilationCacheInterface* parent = nullptr; // Not owned. 107 108 bool initialized = false; 109 110 // The Status returned by the compilation function when the entry is 111 // initialized. This status will be returned to any client that requests the 112 // entry. 113 Status initialization_status; 114 115 // Counter to keep track of LRU entries for the eviction policy. 116 int64_t last_use = -1; 117 118 // The unique key describing this entry. 119 std::string subgraph_key; 120 121 // The uid describing this entry. 122 int64_t uid; 123 124 // Compilation cache proto key to identify the cache entry. 125 std::vector<std::string> proto_key; 126 127 // Fingerprints of sharding programs if there is any. 128 std::vector<std::string> sharding_key; 129 130 // The number of 'external' client-held references to the entry. 
131 int external_references = 0; 132 133 // The sum of the SpaceUsed of each of the elements of programs; an estimate 134 // of how much RAM the entry consumes, used to determine when entries must 135 // be marked for eviction. 136 int64_t total_size = 0; 137 138 // Debug info in case we miss. 139 std::string cache_entry_debug_string; 140 141 // Entries representing the associated sharding and unsharding programs, 142 // which share the same life time of the owning main entry, so we always use 143 // the main entry's ref count. 144 std::unique_ptr<CompiledSubgraph> sharding_entry; 145 std::unique_ptr<CompiledSubgraph> unsharding_entry; 146 147 // Only used for the nested sharding/unsharding entries to point to the 148 // owning main entry. 149 CompiledSubgraph* main_entry = nullptr; 150 151 // Compiled TPU program group. 152 std::unique_ptr<TpuProgramGroupInterface> tpu_program_group; 153 154 // Computes total program size. ComputeTotalSizeCompiledSubgraph155 size_t ComputeTotalSize() const { 156 CHECK_EQ(total_size, 0); 157 int64_t size = tpu_program_group->program_size(); 158 159 if (sharding_entry != nullptr) { 160 size += sharding_entry->total_size; 161 } 162 if (unsharding_entry != nullptr) { 163 size += unsharding_entry->total_size; 164 } 165 return size; 166 } 167 }; 168 169 } // namespace tpu 170 } // namespace tensorflow 171 172 #endif // TENSORFLOW_CORE_TPU_KERNELS_COMPILED_SUBGRAPH_H_ 173