1 #include <gtest/gtest.h> 2 3 #include <torch/csrc/jit/frontend/parser.h> 4 #include <torch/csrc/jit/frontend/resolver.h> 5 6 namespace torch { 7 namespace jit { 8 constexpr c10::string_view testSource = R"JIT( 9 class FooTest: 10 def __init__(self, x): 11 self.x = x 12 13 def get_x(self): 14 return self.x 15 16 an_attribute : Tensor 17 )JIT"; 18 TEST(ClassParserTest,Basic)19TEST(ClassParserTest, Basic) { 20 Parser p(std::make_shared<Source>(testSource)); 21 std::vector<Def> definitions; 22 std::vector<Resolver> resolvers; 23 24 const auto classDef = ClassDef(p.parseClass()); 25 p.lexer().expect(TK_EOF); 26 27 ASSERT_EQ(classDef.name().name(), "FooTest"); 28 ASSERT_EQ(classDef.body().size(), 3); 29 ASSERT_EQ(Def(classDef.body()[0]).name().name(), "__init__"); 30 ASSERT_EQ(Def(classDef.body()[1]).name().name(), "get_x"); 31 ASSERT_EQ( 32 Var(Assign(classDef.body()[2]).lhs()).name().name(), "an_attribute"); 33 ASSERT_FALSE(Assign(classDef.body()[2]).rhs().present()); 34 ASSERT_TRUE(Assign(classDef.body()[2]).type().present()); 35 } 36 } // namespace jit 37 } // namespace torch 38