1 #pragma once 2 3 #include <ATen/Config.h> 4 #include <torch/csrc/jit/api/module.h> 5 #include <torch/csrc/jit/ir/ir.h> 6 #include <torch/csrc/jit/passes/subgraph_rewrite.h> 7 8 #if AT_MKLDNN_ENABLED() 9 10 #include <ideep/tensor.hpp> 11 12 #endif // AT_MKLDNN_ENABLED() 13 14 namespace torch::jit { 15 16 #if AT_MKLDNN_ENABLED() 17 18 namespace mkldnn { 19 20 const static std::map<std::string, std::vector<torch::jit::MatchFilter>> 21 fusion_rewrite_map = { 22 {"none", {}}, 23 {"relu", {}}, 24 }; 25 26 } // namespace mkldnn 27 28 #endif // AT_MKLDNN_ENABLED() 29 30 void FuseConvWithEltwise(std::shared_ptr<Graph>& graph); 31 32 } // namespace torch::jit 33