/** \brief Fusing linear patterns as a single at::linear for easier pattern
 * matching in later passes.
 */
#pragma once

#include <torch/csrc/jit/ir/ir.h>

namespace torch::jit {

/** \brief Match the at::linear pattern and fuse it into a single at::linear.
 *
 * This pass fuses the addmm or matmul + add sequences generated by the JIT
 * back into a single aten::linear node.
 * This pass can be deleted once the JIT can emit aten::linear directly.
 *
 * \param graph the graph to rewrite in place.
 */
TORCH_API void FuseLinear(std::shared_ptr<Graph>& graph);

/** \brief Swap functional linear CallFunction nodes to aten::linear.
 *
 * \param graph the graph to rewrite in place.
 */
TORCH_API void SwapFunctionalLinear(std::shared_ptr<Graph>& graph);

/** \brief Swap all functional linear CallFunction nodes in a module.
 *
 * Applies the graph-level swap to every graph owned by \p module.
 *
 * \param module the module whose method graphs are rewritten in place.
 */
TORCH_API void SwapFunctionalLinear(Module& module);

} // namespace torch::jit