xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_create_autodiff_subgraphs.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include "test/cpp/jit/test_utils.h"
4 
5 #include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
6 
7 namespace torch {
8 namespace jit {
9 
TEST(CreateAutodiffSubgraphsTest,Basic)10 TEST(CreateAutodiffSubgraphsTest, Basic) {
11   auto graph = build_lstm();
12   CreateAutodiffSubgraphs(graph, /*threshold=*/2);
13   // all of the ops are within the DifferentiableGraph
14   testing::FileCheck()
15       .check_not("aten::mm")
16       ->check_not("aten::sigmoid")
17       ->check_not("aten::tanh")
18       ->check_not("aten::mul")
19       ->check("DifferentiableGraph")
20       ->check_next("return")
21       ->run(*graph);
22 }
23 
24 } // namespace jit
25 } // namespace torch
26