1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ 17 #define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ 18 19 #include <memory> 20 #include <string> 21 #include <utility> 22 23 #include "absl/container/flat_hash_map.h" 24 #include "absl/container/inlined_vector.h" 25 #include "absl/types/optional.h" 26 #include "absl/types/span.h" 27 #include "absl/types/variant.h" 28 #include "tensorflow/compiler/jit/xla_compilation_cache.pb.h" 29 #include "tensorflow/compiler/tf2xla/xla_compiler.h" 30 #include "tensorflow/compiler/tf2xla/xla_context.h" 31 #include "tensorflow/compiler/xla/client/local_client.h" 32 #include "tensorflow/compiler/xla/service/hlo.pb.h" 33 #include "tensorflow/compiler/xla/statusor.h" 34 #include "tensorflow/core/common_runtime/device.h" 35 #include "tensorflow/core/common_runtime/device_mgr.h" 36 #include "tensorflow/core/framework/graph.pb.h" 37 #include "tensorflow/core/framework/op_kernel.h" 38 #include "tensorflow/core/lib/core/threadpool.h" 39 #include "tensorflow/core/platform/mutex.h" 40 #include "tensorflow/core/platform/thread_annotations.h" 41 #include "tensorflow/core/protobuf/meta_graph.pb.h" 42 43 namespace tensorflow { 44 45 // The XlaCompilationCache class caches the results of the XlaCompiler class, 46 // which converts a Tensorflow graph into a compiled XLA compilation. 47 // 48 // Since XLA computations must have static shapes, the cache generates a new 49 // XLA computation for each new set of input shapes. 50 // 51 // Currently no cache eviction policy is implemented and the cache grows without 52 // bound. 53 class XlaCompilationCache : public ResourceBase { 54 public: 55 struct Config { ConfigConfig56 Config() {} ConfigConfig57 explicit Config(absl::string_view persistent_cache_directory, 58 bool disable_strict_signature_checks, 59 absl::string_view persistance_prefix) 60 : persistent_cache_directory(persistent_cache_directory), 61 disable_strict_signature_checks(disable_strict_signature_checks), 62 persistance_prefix(persistance_prefix) {} 63 64 // If non-empty, JIT-compiled executables are saved to and loaded from the 65 // specified file system directory path. 66 std::string persistent_cache_directory; 67 68 // Disable strict signature checks for entries loaded into the cache from 69 // external sources. 70 bool disable_strict_signature_checks = false; 71 72 // The cache persistence prefix to use if serializing/deserialzing entries. 73 std::string persistance_prefix; 74 }; 75 XlaCompilationCache(Config config, xla::LocalClient* client, 76 DeviceType device_type); 77 ~XlaCompilationCache() override; 78 79 enum class CompileMode { 80 kLazy, 81 kStrict, 82 kAsync, 83 }; 84 85 enum class CompileState { kUncompiled, kCompiling, kCompiled }; 86 87 enum class CompileScope { 88 kOp, 89 kFunction, 90 }; 91 92 // Compiles a function into a XlaCompiler::CompilationResult that can be used 93 // to execute an XLA Computation. Compilation results are cached. 94 // `function` is the name of a Tensorflow function to compile. 95 // `args` is a description of the arguments to the computation. 96 // 97 // `compile_mode` controls the behavior of the compilation cache on a cache 98 // miss. If `compile_mode` is `kLazy` then, based on some profitability 99 // heuristics, the compilation cache may decide not to compile the cluster at 100 // this time. In this case it returns null into both `out_compilation_result` 101 // and `out_executable`. If `compile_mode` is `kStrict` then the compilation 102 // cache always attempts the compilation on a cache miss. If compilation mode 103 // is 'kAsync' compilation of the cluster happens in the background while the 104 // fallback path executes. 105 // 106 // The result of compilation is written to `*out_compilation_result`, which 107 // must be non-null. If `out_executable` is non-null, also builds an 108 // xla::LocalExecutable and sets `out_executable` to point to it. The 109 // resulting executable pointer may be null if the computation has no 110 // non-constant outputs. 111 Status Compile(const XlaCompiler::Options& options, 112 const NameAttrList& function, 113 const std::vector<XlaCompiler::Argument>& args, 114 const XlaCompiler::CompileOptions& compile_options, 115 CompileMode compile_mode, 116 const XlaCompiler::CompilationResult** out_compilation_result, 117 xla::LocalExecutable** out_executable); 118 119 // As above, but calls XlaCompiler::CompileSingleOp instead of 120 // XlaCompiler::CompileFunction. If MLIR bridge is enabled through ConfigProto 121 // in OpKernelContext, then uses MLIR bridge for compilation instead of 122 // XlaCompiler, if possible. 123 Status CompileSingleOp( 124 const XlaCompiler::Options& options, 125 const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx, 126 const XlaCompiler::CompileOptions& compile_options, 127 const XlaCompiler::CompilationResult** out_compilation_result, 128 xla::LocalExecutable** out_executable); 129 client()130 xla::LocalClient* client() const { return client_; } device_type()131 const DeviceType& device_type() const { return device_type_; } 132 133 string DebugString() const override; 134 135 // Describes the types, shapes and any compile-time constant arguments 136 // to a kernel. Key that uniquely identifies a compilation output. 137 struct Signature { 138 string name; 139 140 // List of args (either as a TensorTypeAndShape or as a Tensor value) 141 // for compile-time constant arguments to the compilation, ordered by 142 // argument number. Tensors must be in host memory. 143 using TensorTypeAndShape = 144 std::pair<DataType, absl::InlinedVector<int64_t, 4>>; 145 absl::InlinedVector<absl::variant<Tensor, TensorTypeAndShape>, 8> args; 146 147 bool operator==(const Signature& other) const; 148 149 struct Hash { 150 uint64 operator()(const Signature& signature) const; 151 }; 152 153 // Returns a human-readable description of the signature. 154 string HumanString() const; 155 }; 156 157 // Builds the signature for a compilation. 158 static StatusOr<Signature> BuildSignature( 159 const NameAttrList& function, 160 absl::Span<const XlaCompiler::Argument> args); 161 162 private: 163 // Common implementation of Compile and CompileSingleOp. The `OpKernelContext` 164 // parameter is always null for the former. 165 Status CompileImpl( 166 const XlaCompiler::CompileOptions& compile_options, 167 const XlaCompiler::Options& options, const NameAttrList& function, 168 const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx, 169 CompileScope scope, CompileMode compile_mode, 170 const XlaCompiler::CompilationResult** out_compilation_result, 171 xla::LocalExecutable** out_executable); 172 173 // Takes `result` which has been compiled from a Tensorflow subgraph to a 174 // XLA computation already, and generates an XLA LocalExecutable `executable`. 175 Status BuildExecutable(const XlaCompiler::Options& options, 176 const XlaCompiler::CompilationResult& result, 177 std::unique_ptr<xla::LocalExecutable>* executable); 178 179 // Like BuildExecutable above, except that it generates an XLA 180 // AotCompilationResult (instead of LocalExecutable), which can be persisted 181 // to later load a LocalExecutable using the LoadExecutable() method below. 182 StatusOr<std::unique_ptr<xla::AotCompilationResult>> 183 BuildSerializedExecutable(const XlaCompiler::Options& options, 184 const XlaCompiler::CompilationResult& result); 185 186 // Returns an XLA LocalExecutable loaded from a serialized XLA 187 // AotCompilationResult. 188 StatusOr<std::unique_ptr<xla::LocalExecutable>> LoadExecutable( 189 const XlaCompiler::Options& options, 190 const XlaCompiler::CompilationResult& result, 191 const std::string& serialized_aot_result); 192 193 // Determines whether the cluster should be compiled. 194 bool ShouldCompileCluster(CompileMode compile_mode, bool is_first_execution, 195 int64_t current_request_count, 196 const NameAttrList& function); 197 198 xla::LocalClient* const client_; 199 const DeviceType device_type_; 200 bool disable_strict_signature_checks_; 201 std::string persistance_prefix_; 202 203 // The value associated with a cache entry. 204 struct Entry { 205 mutex mu; 206 207 // The current compilation state for this entry. 208 CompileState compile_state = CompileState::kUncompiled; 209 210 // The number of times a compilation with this signature has been requested. 211 int64_t request_count = 0; 212 213 // Did compilation succeed? 214 Status compilation_status TF_GUARDED_BY(mu); 215 216 // Output of the XlaCompiler. 217 XlaCompiler::CompilationResult compilation_result TF_GUARDED_BY(mu); 218 219 // The XLA executable compiled from <computation>. May be null if no 220 // executable has been built. 221 std::unique_ptr<xla::LocalExecutable> executable TF_GUARDED_BY(mu); 222 }; 223 224 // Returns a cache key proto that identifies an entry in the compilation 225 // cache. 226 XlaSerializedCacheKey BuildSerializedCacheKey( 227 const Signature& sig, const xla::HloModuleProto& hlo_module) const; 228 229 // Serializes the signature and its corresponding entry to a proto message. 230 StatusOr<XlaSerializedCacheEntry> SerializeEntry( 231 const XlaCompiler::Options& options, const Signature& sig, 232 const Entry& entry) TF_EXCLUSIVE_LOCKS_REQUIRED(entry.mu); 233 234 // Checks if the loaded `entry` matches the expected `key` and `hlo_module`. 235 Status VerifyLoadedCacheEntry(const XlaSerializedCacheKey& key, 236 const xla::HloModuleProto& hlo_module, 237 const XlaSerializedCacheEntry& entry); 238 239 Status CompileStrict(const Signature& sig, Entry* entry, 240 const XlaCompiler::CompileOptions& compile_options, 241 const XlaCompiler::Options& options, 242 const std::vector<XlaCompiler::Argument>& args, 243 const NameAttrList& function, OpKernelContext* ctx, 244 CompileScope scope) 245 TF_EXCLUSIVE_LOCKS_REQUIRED(entry->mu); 246 Status CompileAsynchronous(const Signature& sig, Entry* entry, 247 const XlaCompiler::CompileOptions& compile_options, 248 const XlaCompiler::Options& options, 249 const std::vector<XlaCompiler::Argument>& args, 250 const NameAttrList& function, OpKernelContext* ctx, 251 CompileScope scope); 252 253 // Saves the cache entry in the file directory supplied during the 254 // construction of this class. Overwrites existing entries. 255 Status SaveSerializedEntry(const XlaSerializedCacheEntry& entry); 256 257 // Tries to load a cache entry given a `key` by searching the file directory 258 // supplied during the construction of this class. Returns std::nullopt if no 259 // cache entry is found. 260 StatusOr<std::optional<XlaSerializedCacheEntry>> TryLoadSerializedEntry( 261 const XlaSerializedCacheKey& key); 262 263 mutex compile_cache_mu_; 264 absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_ 265 TF_GUARDED_BY(compile_cache_mu_); 266 267 struct ClusterCompileStats { 268 // Number of times the cluster has been (re-)compiled. 269 int64_t compile_count = 0; 270 271 // The number of times this cluster has been executed. 272 int64_t execution_count = 0; 273 274 // Cumulative time spent compiling the cluster. 275 int64_t cumulative_compile_time_us = 0; 276 }; 277 278 mutex cluster_compile_stats_mu_; 279 280 // Maps cluster names to compilation statistics for said cluster. 281 absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_ 282 TF_GUARDED_BY(cluster_compile_stats_mu_); 283 284 struct AsyncCompilationState { 285 mutex async_compilation_state_mu; 286 287 // Number of threads for asynchronous compilations. 288 static constexpr int64_t kNumCompilerThreads = 10; 289 290 // Maximum number of ongoing compilations. 291 static constexpr int64_t kMaxNumOngoingCompilations = kNumCompilerThreads; 292 293 // Number of ongoing compilations. 294 int64_t num_ongoing_compilations TF_GUARDED_BY(async_compilation_state_mu) = 295 0; 296 297 // Pool of threads for asynchronous compilations. 298 std::unique_ptr<thread::ThreadPool> compiler_threads; 299 AsyncCompilationStateAsyncCompilationState300 AsyncCompilationState() { 301 compiler_threads = std::make_unique<tensorflow::thread::ThreadPool>( 302 tensorflow::Env::Default(), "async_compiler_threads", 303 kNumCompilerThreads); 304 } 305 306 } async_compilation_state_; 307 308 // The number of times a lazy compilation must be requested for a specific 309 // signature before we attempt to compile it. 310 static constexpr int64_t kDefaultCompilationThreshold = 2; 311 312 // If non-empty, JIT-compiled executables are saved to and loaded from the 313 // specified file system directory path. 314 std::string persistent_cache_directory_; 315 316 TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache); 317 }; 318 319 // Creates a single-node graph using the specified node_def as the only op apart 320 // from the arg and retval nodes. 321 StatusOr<std::unique_ptr<Graph>> CreateGraph( 322 const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args, 323 absl::Span<const DataType> result_types); 324 325 // Use XlaCompiler to compile a single op into HLO. 326 Status XlaSingleOpToHlo( 327 XlaCompiler* compiler, const XlaCompiler::Options& options, 328 const std::vector<XlaCompiler::Argument>& args, 329 const XlaCompiler::SingleOpCompileArgument& single_op_compile_argument, 330 const XlaCompiler::CompileOptions& compile_options, 331 XlaCompiler::CompilationResult* compilation_result); 332 333 } // namespace tensorflow 334 335 #endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_ 336