1 #include <gtest/gtest.h> 2 3 #include "test/cpp/jit/test_utils.h" 4 5 #include "torch/csrc/jit/passes/create_autodiff_subgraphs.h" 6 7 namespace torch { 8 namespace jit { 9 TEST(CreateAutodiffSubgraphsTest,Basic)10TEST(CreateAutodiffSubgraphsTest, Basic) { 11 auto graph = build_lstm(); 12 CreateAutodiffSubgraphs(graph, /*threshold=*/2); 13 // all of the ops are within the DifferentiableGraph 14 testing::FileCheck() 15 .check_not("aten::mm") 16 ->check_not("aten::sigmoid") 17 ->check_not("aten::tanh") 18 ->check_not("aten::mul") 19 ->check("DifferentiableGraph") 20 ->check_next("return") 21 ->run(*graph); 22 } 23 24 } // namespace jit 25 } // namespace torch 26