/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_

#include <memory>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/compiler/jit/xla_compilation_cache.pb.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

namespace tensorflow {

// The XlaCompilationCache class caches the results of the XlaCompiler class,
// which converts a TensorFlow graph into a compiled XLA computation.
//
// Since XLA computations must have static shapes, the cache generates a new
// XLA computation for each new set of input shapes.
//
// Currently no cache eviction policy is implemented and the cache grows
// without bound.
class XlaCompilationCache : public ResourceBase {
 public:
  struct Config {
    Config() {}
    explicit Config(absl::string_view persistent_cache_directory,
                    bool disable_strict_signature_checks,
                    absl::string_view persistance_prefix)
        : persistent_cache_directory(persistent_cache_directory),
          disable_strict_signature_checks(disable_strict_signature_checks),
          persistance_prefix(persistance_prefix) {}

    // If non-empty, JIT-compiled executables are saved to and loaded from the
    // specified file system directory path.
    std::string persistent_cache_directory;

    // Disable strict signature checks for entries loaded into the cache from
    // external sources.
    bool disable_strict_signature_checks = false;

    // The cache persistence prefix to use if serializing/deserializing
    // entries.
    std::string persistance_prefix;
  };
  XlaCompilationCache(Config config, xla::LocalClient* client,
                      DeviceType device_type);
  ~XlaCompilationCache() override;
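
  // A minimal construction sketch (illustrative only, not code from this
  // file); `client` and `device_type` stand for a valid xla::LocalClient* and
  // DeviceType obtained from the surrounding JIT machinery, and the directory
  // and prefix values are placeholders:
  //
  //   XlaCompilationCache::Config config(
  //       /*persistent_cache_directory=*/"/tmp/xla_cache",
  //       /*disable_strict_signature_checks=*/false,
  //       /*persistance_prefix=*/"tf_xla");
  //   auto* cache =
  //       new XlaCompilationCache(std::move(config), client, device_type);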

  enum class CompileMode {
    kLazy,
    kStrict,
    kAsync,
  };

  enum class CompileState { kUncompiled, kCompiling, kCompiled };

  enum class CompileScope {
    kOp,
    kFunction,
  };

  // Compiles a function into an XlaCompiler::CompilationResult that can be
  // used to execute an XLA computation. Compilation results are cached.
  // `function` is the name of a TensorFlow function to compile.
  // `args` is a description of the arguments to the computation.
  //
  // `compile_mode` controls the behavior of the compilation cache on a cache
  // miss.  If `compile_mode` is `kLazy` then, based on some profitability
  // heuristics, the compilation cache may decide not to compile the cluster at
  // this time.  In this case it returns null in both `out_compilation_result`
  // and `out_executable`.  If `compile_mode` is `kStrict` then the compilation
  // cache always attempts the compilation on a cache miss.  If `compile_mode`
  // is `kAsync`, compilation of the cluster happens in the background while
  // the fallback path executes.
  //
  // The result of compilation is written to `*out_compilation_result`, which
  // must be non-null. If `out_executable` is non-null, also builds an
  // xla::LocalExecutable and sets `out_executable` to point to it. The
  // resulting executable pointer may be null if the computation has no
  // non-constant outputs.
  Status Compile(const XlaCompiler::Options& options,
                 const NameAttrList& function,
                 const std::vector<XlaCompiler::Argument>& args,
                 const XlaCompiler::CompileOptions& compile_options,
                 CompileMode compile_mode,
                 const XlaCompiler::CompilationResult** out_compilation_result,
                 xla::LocalExecutable** out_executable);
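
  // A call-site sketch for Compile() above (illustrative only, not code from
  // this file); `cache`, `options`, `function`, `args`, and `compile_options`
  // are assumed to be supplied by the caller, e.g. an XLA launch kernel:
  //
  //   const XlaCompiler::CompilationResult* compilation_result = nullptr;
  //   xla::LocalExecutable* executable = nullptr;
  //   TF_RETURN_IF_ERROR(cache->Compile(
  //       options, function, args, compile_options,
  //       XlaCompilationCache::CompileMode::kStrict, &compilation_result,
  //       &executable));
  //   // With kStrict, `compilation_result` is populated on success;
  //   // `executable` may still be null if the computation has no
  //   // non-constant outputs.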

  // As above, but calls XlaCompiler::CompileSingleOp instead of
  // XlaCompiler::CompileFunction. If the MLIR bridge is enabled through the
  // ConfigProto in OpKernelContext, then uses the MLIR bridge for compilation
  // instead of XlaCompiler, if possible.
  Status CompileSingleOp(
      const XlaCompiler::Options& options,
      const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
      const XlaCompiler::CompileOptions& compile_options,
      const XlaCompiler::CompilationResult** out_compilation_result,
      xla::LocalExecutable** out_executable);

  xla::LocalClient* client() const { return client_; }
  const DeviceType& device_type() const { return device_type_; }

  string DebugString() const override;

  // Describes the types, shapes, and any compile-time constant arguments of a
  // kernel invocation; this is the key that uniquely identifies a compilation
  // output.
  struct Signature {
    string name;

    // List of args (either as a TensorTypeAndShape or as a Tensor value)
    // for compile-time constant arguments to the compilation, ordered by
    // argument number. Tensors must be in host memory.
    using TensorTypeAndShape =
        std::pair<DataType, absl::InlinedVector<int64_t, 4>>;
    absl::InlinedVector<absl::variant<Tensor, TensorTypeAndShape>, 8> args;

    bool operator==(const Signature& other) const;

    struct Hash {
      uint64 operator()(const Signature& signature) const;
    };

    // Returns a human-readable description of the signature.
    string HumanString() const;
  };

  // Builds the signature for a compilation.
  static StatusOr<Signature> BuildSignature(
      const NameAttrList& function,
      absl::Span<const XlaCompiler::Argument> args);
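
  // A key-building sketch (illustrative only, not code from this file);
  // assumes `function` and `args` describe the cluster being compiled and
  // that the TF_ASSIGN_OR_RETURN macro is available:
  //
  //   TF_ASSIGN_OR_RETURN(
  //       XlaCompilationCache::Signature sig,
  //       XlaCompilationCache::BuildSignature(function, args));
  //   VLOG(2) << "Compilation cache key: " << sig.HumanString();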

 private:
  // Common implementation of Compile and CompileSingleOp. The `OpKernelContext`
  // parameter is always null for the former.
  Status CompileImpl(
      const XlaCompiler::CompileOptions& compile_options,
      const XlaCompiler::Options& options, const NameAttrList& function,
      const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
      CompileScope scope, CompileMode compile_mode,
      const XlaCompiler::CompilationResult** out_compilation_result,
      xla::LocalExecutable** out_executable);

  // Takes `result`, which has already been compiled from a TensorFlow subgraph
  // to an XLA computation, and generates an XLA LocalExecutable `executable`.
  Status BuildExecutable(const XlaCompiler::Options& options,
                         const XlaCompiler::CompilationResult& result,
                         std::unique_ptr<xla::LocalExecutable>* executable);

  // Like BuildExecutable above, except that it generates an XLA
  // AotCompilationResult (instead of a LocalExecutable), which can be persisted
  // to later load a LocalExecutable using the LoadExecutable() method below.
  StatusOr<std::unique_ptr<xla::AotCompilationResult>>
  BuildSerializedExecutable(const XlaCompiler::Options& options,
                            const XlaCompiler::CompilationResult& result);

  // Returns an XLA LocalExecutable loaded from a serialized XLA
  // AotCompilationResult.
  StatusOr<std::unique_ptr<xla::LocalExecutable>> LoadExecutable(
      const XlaCompiler::Options& options,
      const XlaCompiler::CompilationResult& result,
      const std::string& serialized_aot_result);

  // Determines whether the cluster should be compiled.
  bool ShouldCompileCluster(CompileMode compile_mode, bool is_first_execution,
                            int64_t current_request_count,
                            const NameAttrList& function);

  xla::LocalClient* const client_;
  const DeviceType device_type_;
  bool disable_strict_signature_checks_;
  std::string persistance_prefix_;

  // The value associated with a cache entry.
  struct Entry {
    mutex mu;

    // The current compilation state for this entry.
    CompileState compile_state = CompileState::kUncompiled;

    // The number of times a compilation with this signature has been
    // requested.
    int64_t request_count = 0;

    // Did compilation succeed?
    Status compilation_status TF_GUARDED_BY(mu);

    // Output of the XlaCompiler.
    XlaCompiler::CompilationResult compilation_result TF_GUARDED_BY(mu);

    // The XLA executable compiled from <computation>. May be null if no
    // executable has been built.
    std::unique_ptr<xla::LocalExecutable> executable TF_GUARDED_BY(mu);
  };

  // Returns a cache key proto that identifies an entry in the compilation
  // cache.
  XlaSerializedCacheKey BuildSerializedCacheKey(
      const Signature& sig, const xla::HloModuleProto& hlo_module) const;

  // Serializes the signature and its corresponding entry to a proto message.
  StatusOr<XlaSerializedCacheEntry> SerializeEntry(
      const XlaCompiler::Options& options, const Signature& sig,
      const Entry& entry) TF_EXCLUSIVE_LOCKS_REQUIRED(entry.mu);

  // Checks if the loaded `entry` matches the expected `key` and `hlo_module`.
  Status VerifyLoadedCacheEntry(const XlaSerializedCacheKey& key,
                                const xla::HloModuleProto& hlo_module,
                                const XlaSerializedCacheEntry& entry);

  Status CompileStrict(const Signature& sig, Entry* entry,
                       const XlaCompiler::CompileOptions& compile_options,
                       const XlaCompiler::Options& options,
                       const std::vector<XlaCompiler::Argument>& args,
                       const NameAttrList& function, OpKernelContext* ctx,
                       CompileScope scope)
      TF_EXCLUSIVE_LOCKS_REQUIRED(entry->mu);
  Status CompileAsynchronous(const Signature& sig, Entry* entry,
                             const XlaCompiler::CompileOptions& compile_options,
                             const XlaCompiler::Options& options,
                             const std::vector<XlaCompiler::Argument>& args,
                             const NameAttrList& function, OpKernelContext* ctx,
                             CompileScope scope);

  // Saves the cache entry in the file directory supplied during the
  // construction of this class. Overwrites existing entries.
  Status SaveSerializedEntry(const XlaSerializedCacheEntry& entry);

  // Tries to load a cache entry given a `key` by searching the file directory
  // supplied during the construction of this class. Returns std::nullopt if no
  // cache entry is found.
  StatusOr<std::optional<XlaSerializedCacheEntry>> TryLoadSerializedEntry(
      const XlaSerializedCacheKey& key);

  mutex compile_cache_mu_;
  absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
      TF_GUARDED_BY(compile_cache_mu_);

  struct ClusterCompileStats {
    // Number of times the cluster has been (re-)compiled.
    int64_t compile_count = 0;

    // The number of times this cluster has been executed.
    int64_t execution_count = 0;

    // Cumulative time spent compiling the cluster.
    int64_t cumulative_compile_time_us = 0;
  };

  mutex cluster_compile_stats_mu_;

  // Maps cluster names to compilation statistics for said cluster.
  absl::flat_hash_map<string, ClusterCompileStats> cluster_compile_stats_
      TF_GUARDED_BY(cluster_compile_stats_mu_);

  struct AsyncCompilationState {
    mutex async_compilation_state_mu;

    // Number of threads for asynchronous compilations.
    static constexpr int64_t kNumCompilerThreads = 10;

    // Maximum number of ongoing compilations.
    static constexpr int64_t kMaxNumOngoingCompilations = kNumCompilerThreads;

    // Number of ongoing compilations.
    int64_t num_ongoing_compilations TF_GUARDED_BY(async_compilation_state_mu) =
        0;

    // Pool of threads for asynchronous compilations.
    std::unique_ptr<thread::ThreadPool> compiler_threads;

    AsyncCompilationState() {
      compiler_threads = std::make_unique<tensorflow::thread::ThreadPool>(
          tensorflow::Env::Default(), "async_compiler_threads",
          kNumCompilerThreads);
    }

  } async_compilation_state_;

  // The number of times a lazy compilation must be requested for a specific
  // signature before we attempt to compile it.
  static constexpr int64_t kDefaultCompilationThreshold = 2;

  // If non-empty, JIT-compiled executables are saved to and loaded from the
  // specified file system directory path.
  std::string persistent_cache_directory_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
};

// Creates a single-node graph using the specified node_def as the only op
// apart from the arg and retval nodes.
StatusOr<std::unique_ptr<Graph>> CreateGraph(
    const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
    absl::Span<const DataType> result_types);

// Use XlaCompiler to compile a single op into HLO.
Status XlaSingleOpToHlo(
    XlaCompiler* compiler, const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args,
    const XlaCompiler::SingleOpCompileArgument& single_op_compile_argument,
    const XlaCompiler::CompileOptions& compile_options,
    XlaCompiler::CompilationResult* compilation_result);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_