xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/onednn/interface.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/Config.h>
3 #include <torch/csrc/jit/ir/ir.h>
4 #include <torch/csrc/jit/passes/pass_manager.h>
5 
6 namespace torch {
7 namespace jit {
8 namespace fuser {
9 namespace onednn {
10 
// Flag recording whether the oneDNN Graph (LLGA) fuser is enabled; starts
// out false.
//
// Declared `inline` (C++17) instead of `static`: a namespace-scope `static`
// in a header gives every translation unit that includes it a private copy,
// so a value stored through one translation unit would not be observed
// through another. An inline variable has a single definition shared by all
// translation units of the program.
inline std::atomic<bool> onednn_enabled{false};

// Returns a mutable reference to the enabled flag so callers can both read
// and assign it. `inline` for the same reason as the variable: the
// implicitly-inline members of RegisterLlgaFuseGraph call this function and
// must name the same entity in every translation unit.
inline std::atomic<bool>& getLlgaEnabled() {
  return onednn_enabled;
}
16 
// Graph pass of the oneDNN Graph fuser. Takes the graph through a mutable
// shared_ptr reference. Only declared in this header; the definition lives
// elsewhere. RegisterLlgaFuseGraph::setEnabled hands this function to
// registerPass when the fuser is switched on.
C10_EXPORT void fuseGraph(std::shared_ptr<Graph>& g);
18 
19 } // namespace onednn
20 } // namespace fuser
21 
22 struct C10_EXPORT RegisterLlgaFuseGraph
23     : public PassManager<RegisterLlgaFuseGraph> {
setEnabledRegisterLlgaFuseGraph24   static bool setEnabled(bool enabled) {
25     TORCH_CHECK(
26         AT_MKLDNN_ENABLED(),
27         "Running oneDNN Graph fuser is only supported with MKLDNN builds.");
28     bool oldState = fuser::onednn::getLlgaEnabled();
29     fuser::onednn::getLlgaEnabled() = enabled;
30     if (enabled) {
31       registerPass(fuser::onednn::fuseGraph);
32     } else {
33       clearPass();
34     }
35     return oldState;
36   }
37 
isEnabledRegisterLlgaFuseGraph38   static bool isEnabled() {
39     return fuser::onednn::getLlgaEnabled();
40   }
41 
42   // override PassManager::registerPass to register pre-pass
registerPassRegisterLlgaFuseGraph43   static bool registerPass(GraphPass p) {
44     if (!isRegistered()) {
45       passID(registerPrePass(std::move(p)), true);
46       isRegistered(true);
47       return false;
48     }
49     return true;
50   }
51 
52   // override PassManager::clearPass to clear pre-pass
clearPassRegisterLlgaFuseGraph53   static void clearPass() {
54     if (isRegistered()) {
55       clearPrePass(passID());
56       isRegistered(true);
57     }
58   }
59 };
60 
61 } // namespace jit
62 } // namespace torch
63