xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/constant_propagation.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 
5 namespace torch::jit {
6 
7 // Runs constant propagation on all objects unless ignore_custom_classes is
8 // specified as true, in which case user defined classes are skipped.  This is
9 // useful to prevent early fusion of packing operations, which end up lowering
10 // away information about their constructors (e.g. packed::linear_clamp_prepack
11 // and prepacked::conv2d_clamp_prepack)
12 // Returns True if the pass made a change to the graph
13 TORCH_API bool ConstantPropagation(
14     std::shared_ptr<Graph>& graph,
15     bool ignore_custom_classes = false);
16 
17 // runs constant propagation only on ops that have non-aliasing inputs & outputs
18 // Returns True if the pass made a change to the graph
19 TORCH_API bool ConstantPropagationImmutableTypes(std::shared_ptr<Graph>& graph);
20 
21 // Runs the node if its inputs are constants. Callers of this function must
22 // make their own determination if constant prop is appropriate - for example
23 // non-deterministic ops or ops with side effects.  If ignore_custom_classes is
24 // specified, nodes that output user defined classes are not run.
25 TORCH_API std::optional<Stack> runNodeIfInputsAreConstant(
26     const Node* node,
27     bool ignore_custom_classes = false,
28     AliasDb* db = nullptr);
29 
30 } // namespace torch::jit
31