xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/fuse_linear.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
/** \brief Fuses linear patterns into a single aten::linear node for easier
 * pattern matching in later passes.
 */
4 #pragma once
5 
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>
7 
8 namespace torch::jit {
9 
10 /** \brief Match the at::linear pattern and fuse it into a single at::linear
11  * This pass fuse the addmm or matmul + add generated by JIT back to linear
12  * This pass can be deleted once the JIT can emit the aten::linear in the future
13  */
14 TORCH_API void FuseLinear(std::shared_ptr<Graph>& graph);
15 
16 /** Swap functional linear CallFunctions to aten::linear
17  */
18 TORCH_API void SwapFunctionalLinear(std::shared_ptr<Graph>& graph);
19 /** Swap all functional linear CallFunctions in module
20  */
21 TORCH_API void SwapFunctionalLinear(Module& module);
22 } // namespace torch::jit
23