xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/specialize_autogradzero.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 
5 namespace torch::jit {
6 
7 // propagate autograd zero information through a gradient graph and
8 // remove grad_of blocks if present.
9 // Note: this is a very limited pass. It only propagates autograd zeros for
10 // operations generated by the symbolic autodiff code and cleans up
11 // AutogradAdds when possible. Outputs of other nodes are conservatively
12 // marked Unknown and not optimized.
13 TORCH_API void specializeAutogradZero(std::shared_ptr<Graph> g);
14 
15 struct ProfilingRecord;
16 
17 TORCH_API void InsertProfileNodesForSpecializeAutogradZero(ProfilingRecord* pr);
18 
19 } // namespace torch::jit
20