// xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_class_parser.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <torch/csrc/jit/frontend/parser.h>
4 #include <torch/csrc/jit/frontend/resolver.h>
5 
6 namespace torch {
7 namespace jit {
8 constexpr c10::string_view testSource = R"JIT(
9   class FooTest:
10     def __init__(self, x):
11       self.x = x
12 
13     def get_x(self):
14       return self.x
15 
16     an_attribute : Tensor
17 )JIT";
18 
TEST(ClassParserTest,Basic)19 TEST(ClassParserTest, Basic) {
20   Parser p(std::make_shared<Source>(testSource));
21   std::vector<Def> definitions;
22   std::vector<Resolver> resolvers;
23 
24   const auto classDef = ClassDef(p.parseClass());
25   p.lexer().expect(TK_EOF);
26 
27   ASSERT_EQ(classDef.name().name(), "FooTest");
28   ASSERT_EQ(classDef.body().size(), 3);
29   ASSERT_EQ(Def(classDef.body()[0]).name().name(), "__init__");
30   ASSERT_EQ(Def(classDef.body()[1]).name().name(), "get_x");
31   ASSERT_EQ(
32       Var(Assign(classDef.body()[2]).lhs()).name().name(), "an_attribute");
33   ASSERT_FALSE(Assign(classDef.body()[2]).rhs().present());
34   ASSERT_TRUE(Assign(classDef.body()[2]).type().present());
35 }
36 } // namespace jit
37 } // namespace torch
38