1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/csrc/jit/codegen/fuser/kernel_spec.h> 5 #include <torch/csrc/jit/ir/ir.h> 6 7 #include <cstdint> 8 #include <functional> 9 #include <optional> 10 11 namespace torch::jit::fuser { 12 13 // A thread-safe cache interface. 14 15 // Normalizes the graph by canonicalizing and erasing shape information 16 TORCH_API std::shared_ptr<Graph> normalizeGraphForCache( 17 const std::shared_ptr<Graph>& graph); 18 19 // Stores the given graph, returning the key used to access it 20 TORCH_API int64_t store(std::shared_ptr<Graph> graph); 21 22 // Given a graph, find a KernelSpec based on it 23 TORCH_API std::optional<KernelSpec*> lookupGraph( 24 const std::shared_ptr<Graph>& graph); 25 26 // Returns the graph corresponding to the given key (if it exists) 27 TORCH_API std::optional<KernelSpec*> retrieve(const int64_t key); 28 29 // Returns the size of the fusion key -> KernelSpec cache. 30 // Only used for testing. 31 TORCH_API int64_t debugNumCachedKernelSpecs(); 32 33 } // namespace torch::jit::fuser 34