1 #include <gtest/gtest.h> 2 3 #include <torch/csrc/jit/api/compilation_unit.h> 4 #include <torch/csrc/jit/api/module.h> 5 #include <torch/csrc/jit/passes/inliner.h> 6 #include <torch/csrc/jit/testing/file_check.h> 7 8 const auto testSource = R"JIT( 9 def foo1(x): 10 print("one") 11 return x 12 13 def foo2(x): 14 print("two") 15 return foo1(x) 16 17 def foo3(x): 18 print("three") 19 return foo2(x) 20 )JIT"; 21 22 namespace torch { 23 namespace jit { 24 using namespace testing; 25 26 struct InlinerGuard { InlinerGuardtorch::jit::InlinerGuard27 explicit InlinerGuard(bool shouldInline) 28 : oldState_(getInlineEverythingMode()) { 29 getInlineEverythingMode() = shouldInline; 30 } 31 ~InlinerGuardtorch::jit::InlinerGuard32 ~InlinerGuard() { 33 getInlineEverythingMode() = oldState_; 34 } 35 36 bool oldState_; 37 }; 38 TEST(InlinerTest,Basic)39TEST(InlinerTest, Basic) { 40 // disable automatic inlining so we can test it manually 41 InlinerGuard guard(/*shouldInline=*/false); 42 43 CompilationUnit cu(testSource); 44 auto& fn = cu.get_function("foo3"); 45 46 auto g = toGraphFunction(fn).graph(); 47 Inline(*g); 48 FileCheck().check_count("prim::Print", 3)->run(*g); 49 } 50 } // namespace jit 51 } // namespace torch 52