1 #pragma once 2 3 #include <torch/csrc/jit/ir/ir.h> 4 5 namespace torch::jit { 6 7 // propagate autograd zero information through a gradient graph and 8 // remove grad_of blocks if present. 9 // Note: this is a very limited pass. It only propagates autograd zeros for 10 // operations generated by the symbolic autodiff code and cleans up 11 // AutogradAdds when possible. Outputs of other nodes are conservatively 12 // marked Unknown and not optimized. 13 TORCH_API void specializeAutogradZero(std::shared_ptr<Graph> g); 14 15 struct ProfilingRecord; 16 17 TORCH_API void InsertProfileNodesForSpecializeAutogradZero(ProfilingRecord* pr); 18 19 } // namespace torch::jit 20