1 #pragma once 2 3 #include <torch/csrc/jit/ir/ir.h> 4 5 namespace torch::jit { 6 7 // Try to replace an op that takes a list input with another op that takes a 8 // variadic number of arguments. 9 TORCH_API bool UseVariadicOp( 10 const std::shared_ptr<Graph>& graph, 11 NodeKind op, 12 NodeKind variadic_op); 13 14 TORCH_API bool RemoveListMutationAndUseVariadicOp( 15 const std::shared_ptr<Graph>& graph, 16 NodeKind op, 17 NodeKind variadic_op); 18 19 // Convenient functions for replacing aten::stack/aten::cat with their 20 // variadic versions. 21 TORCH_API bool UseVariadicCat(const std::shared_ptr<Graph>& graph); 22 TORCH_API bool RemoveListMutationAndUseVariadicCat( 23 const std::shared_ptr<Graph>& graph); 24 25 TORCH_API bool UseVariadicStack(const std::shared_ptr<Graph>& graph); 26 TORCH_API bool RemoveListMutationAndUseVariadicStack( 27 const std::shared_ptr<Graph>& graph); 28 29 } // namespace torch::jit 30