xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/fuser/kernel_cache.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/codegen/fuser/kernel_spec.h>
#include <torch/csrc/jit/ir/ir.h>

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
10 
11 namespace torch::jit::fuser {
12 
13 // A thread-safe cache interface.
14 
15 // Normalizes the graph by canonicalizing and erasing shape information
16 TORCH_API std::shared_ptr<Graph> normalizeGraphForCache(
17     const std::shared_ptr<Graph>& graph);
18 
19 // Stores the given graph, returning the key used to access it
20 TORCH_API int64_t store(std::shared_ptr<Graph> graph);
21 
22 // Given a graph, find a KernelSpec based on it
23 TORCH_API std::optional<KernelSpec*> lookupGraph(
24     const std::shared_ptr<Graph>& graph);
25 
26 // Returns the graph corresponding to the given key (if it exists)
27 TORCH_API std::optional<KernelSpec*> retrieve(const int64_t key);
28 
29 // Returns the size of the fusion key -> KernelSpec cache.
30 // Only used for testing.
31 TORCH_API int64_t debugNumCachedKernelSpecs();
32 
33 } // namespace torch::jit::fuser
34