xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_inliner.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <torch/csrc/jit/api/compilation_unit.h>
4 #include <torch/csrc/jit/api/module.h>
5 #include <torch/csrc/jit/passes/inliner.h>
6 #include <torch/csrc/jit/testing/file_check.h>
7 
// TorchScript program used by the inliner tests: foo3 calls foo2, which in
// turn calls foo1. Each of the three functions contains exactly one print.
const char* const testSource = R"JIT(
def foo1(x):
    print("one")
    return x

def foo2(x):
    print("two")
    return foo1(x)

def foo3(x):
    print("three")
    return foo2(x)
)JIT";
21 
22 namespace torch {
23 namespace jit {
24 using namespace testing;
25 
26 struct InlinerGuard {
InlinerGuardtorch::jit::InlinerGuard27   explicit InlinerGuard(bool shouldInline)
28       : oldState_(getInlineEverythingMode()) {
29     getInlineEverythingMode() = shouldInline;
30   }
31 
~InlinerGuardtorch::jit::InlinerGuard32   ~InlinerGuard() {
33     getInlineEverythingMode() = oldState_;
34   }
35 
36   bool oldState_;
37 };
38 
// Checks that Inline() recursively inlines a chain of calls
// (foo3 -> foo2 -> foo1), pulling each callee's print into foo3's graph.
TEST(InlinerTest, Basic) {
  // Disable automatic inlining during compilation so the calls are kept as
  // call nodes and can be inlined manually below.
  InlinerGuard guard(/*shouldInline=*/false);

  CompilationUnit cu(testSource);
  auto& fn = cu.get_function("foo3");

  auto g = toGraphFunction(fn).graph();

  // Precondition: before inlining, only foo3's own print is in its graph;
  // the prints of foo2 and foo1 are still hidden behind the call.
  FileCheck().check_count("prim::Print", 1, /*exactly=*/true)->run(*g);

  Inline(*g);

  // After inlining, each of the three prints appears exactly once (no
  // duplicated bodies) and no function calls are left in the graph.
  FileCheck().check_count("prim::Print", 3, /*exactly=*/true)->run(*g);
  FileCheck().check_not("prim::CallFunction")->run(*g);
}
50 } // namespace jit
51 } // namespace torch
52