xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/variadic_ops.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 
namespace torch::jit {

// Try to replace an op that takes a list input with another op that takes a
// variadic number of arguments.
//
// @param graph        Graph to transform. It is reached through a pointer to a
//                     non-const Graph, so the pass is able to rewrite it in
//                     place.
// @param op           Kind of the list-consuming node to look for.
// @param variadic_op  Kind of the node that takes a variadic number of
//                     arguments and replaces `op`.
// @return A bool status. NOTE(review): presumably true iff the graph was
//         modified — confirm against the implementation.
TORCH_API bool UseVariadicOp(
    const std::shared_ptr<Graph>& graph,
    NodeKind op,
    NodeKind variadic_op);

// Same parameters as UseVariadicOp. As the name suggests, this variant
// presumably first removes mutations of the list input before substituting
// `variadic_op` for `op`. NOTE(review): the exact list-mutation patterns it
// handles are defined in the implementation — confirm there.
//
// @param graph        Graph to transform (may be rewritten in place).
// @param op           Kind of the list-consuming node to look for.
// @param variadic_op  Kind of the variadic node that replaces `op`.
// @return A bool status; presumably true iff the graph was modified — confirm.
TORCH_API bool RemoveListMutationAndUseVariadicOp(
    const std::shared_ptr<Graph>& graph,
    NodeKind op,
    NodeKind variadic_op);

// Convenience functions for replacing aten::stack/aten::cat with their
// variadic versions.
// NOTE(review): the variadic node kinds they target are chosen in the
// implementation, not in this header — confirm which symbols are used.
TORCH_API bool UseVariadicCat(const std::shared_ptr<Graph>& graph);
TORCH_API bool RemoveListMutationAndUseVariadicCat(
    const std::shared_ptr<Graph>& graph);

TORCH_API bool UseVariadicStack(const std::shared_ptr<Graph>& graph);
TORCH_API bool RemoveListMutationAndUseVariadicStack(
    const std::shared_ptr<Graph>& graph);

} // namespace torch::jit
30