1 #pragma once 2 #include <ATen/Config.h> 3 #include <torch/csrc/jit/ir/ir.h> 4 #include <torch/csrc/jit/passes/pass_manager.h> 5 6 namespace torch { 7 namespace jit { 8 namespace fuser { 9 namespace onednn { 10 11 static std::atomic<bool> onednn_enabled{false}; 12 getLlgaEnabled()13static std::atomic<bool>& getLlgaEnabled() { 14 return onednn_enabled; 15 } 16 17 C10_EXPORT void fuseGraph(std::shared_ptr<Graph>& g); 18 19 } // namespace onednn 20 } // namespace fuser 21 22 struct C10_EXPORT RegisterLlgaFuseGraph 23 : public PassManager<RegisterLlgaFuseGraph> { setEnabledRegisterLlgaFuseGraph24 static bool setEnabled(bool enabled) { 25 TORCH_CHECK( 26 AT_MKLDNN_ENABLED(), 27 "Running oneDNN Graph fuser is only supported with MKLDNN builds."); 28 bool oldState = fuser::onednn::getLlgaEnabled(); 29 fuser::onednn::getLlgaEnabled() = enabled; 30 if (enabled) { 31 registerPass(fuser::onednn::fuseGraph); 32 } else { 33 clearPass(); 34 } 35 return oldState; 36 } 37 isEnabledRegisterLlgaFuseGraph38 static bool isEnabled() { 39 return fuser::onednn::getLlgaEnabled(); 40 } 41 42 // override PassManager::registerPass to register pre-pass registerPassRegisterLlgaFuseGraph43 static bool registerPass(GraphPass p) { 44 if (!isRegistered()) { 45 passID(registerPrePass(std::move(p)), true); 46 isRegistered(true); 47 return false; 48 } 49 return true; 50 } 51 52 // override PassManager::clearPass to clear pre-pass clearPassRegisterLlgaFuseGraph53 static void clearPass() { 54 if (isRegistered()) { 55 clearPrePass(passID()); 56 isRegistered(true); 57 } 58 } 59 }; 60 61 } // namespace jit 62 } // namespace torch 63