1 #pragma once 2 3 #include <torch/csrc/Export.h> 4 #include <torch/csrc/jit/ir/ir.h> 5 6 #include <cstddef> 7 8 namespace torch::jit { 9 10 // insert GraphExecutor nodes that group together 11 // subgraphs that are differentiable by the jit's autodiff passes 12 // threshold - minimum number of nodes that will appear in a block 13 // returns all differentiable blocks that have been found 14 TORCH_API std::vector<Node*> CreateAutodiffSubgraphs( 15 const std::shared_ptr<Graph>& graph, 16 size_t threshold = 2); 17 } // namespace torch::jit 18