xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_class_import.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <utility>

#include <ATen/core/qualified_name.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/serialization/import_source.h>
#include <torch/torch.h>
8 
9 namespace torch {
10 namespace jit {
11 
12 static constexpr c10::string_view classSrcs1 = R"JIT(
13 class FooNestedTest:
14     def __init__(self, y):
15         self.y = y
16 
17 class FooNestedTest2:
18     def __init__(self, y):
19         self.y = y
20         self.nested = __torch__.FooNestedTest(y)
21 
22 class FooTest:
23     def __init__(self, x):
24         self.class_attr = __torch__.FooNestedTest(x)
25         self.class_attr2 = __torch__.FooNestedTest2(x)
26         self.x = self.class_attr.y + self.class_attr2.y
27 )JIT";
28 
29 static constexpr c10::string_view classSrcs2 = R"JIT(
30 class FooTest:
31     def __init__(self, x):
32       self.dx = x
33 )JIT";
34 
import_libs(std::shared_ptr<CompilationUnit> cu,const std::string & class_name,const std::shared_ptr<Source> & src,const std::vector<at::IValue> & tensor_table)35 static void import_libs(
36     std::shared_ptr<CompilationUnit> cu,
37     const std::string& class_name,
38     const std::shared_ptr<Source>& src,
39     const std::vector<at::IValue>& tensor_table) {
40   SourceImporter si(
41       cu,
42       &tensor_table,
43       [&](const std::string& name) -> std::shared_ptr<Source> { return src; },
44       /*version=*/2);
45   si.loadType(QualifiedName(class_name));
46 }
47 
// Importing the same qualified class name into two compilation units must keep
// the two definitions isolated from each other.
TEST(ClassImportTest, Basic) {
  auto cu1 = std::make_shared<CompilationUnit>();
  auto cu2 = std::make_shared<CompilationUnit>();
  std::vector<at::IValue> constantTable;
  // Import two different definitions of `__torch__.FooTest` into two separate
  // compilation units.
  import_libs(
      cu1,
      "__torch__.FooTest",
      std::make_shared<Source>(classSrcs1),
      constantTable);
  import_libs(
      cu2,
      "__torch__.FooTest",
      std::make_shared<Source>(classSrcs2),
      constantTable);

  // Each compilation unit should resolve `FooTest` to its own definition.
  c10::QualifiedName base("__torch__");
  auto classType1 = cu1->get_class(c10::QualifiedName(base, "FooTest"));
  // Fail the test cleanly instead of dereferencing a null class below.
  ASSERT_TRUE(classType1);
  ASSERT_TRUE(classType1->hasAttribute("x"));
  ASSERT_FALSE(classType1->hasAttribute("dx"));

  auto classType2 = cu2->get_class(c10::QualifiedName(base, "FooTest"));
  ASSERT_TRUE(classType2);
  ASSERT_TRUE(classType2->hasAttribute("dx"));
  ASSERT_FALSE(classType2->hasAttribute("x"));

  // FooNestedTest only exists in classSrcs1, so only cu1 should contain it.
  auto c = cu1->get_class(c10::QualifiedName(base, "FooNestedTest"));
  ASSERT_TRUE(c);

  c = cu2->get_class(c10::QualifiedName(base, "FooNestedTest"));
  ASSERT_FALSE(c);
}
82 
// Imports the two FooTest definitions into the compilation units owned by two
// modules, then instantiates FooTest via Module::create_class and exercises
// attribute reads and writes on the resulting object.
TEST(ClassImportTest, ScriptObject) {
  Module m1("m1");
  Module m2("m2");
  std::vector<at::IValue> constantTable;
  import_libs(
      m1._ivalue()->compilation_unit(),
      "__torch__.FooTest",
      std::make_shared<Source>(classSrcs1),
      constantTable);
  import_libs(
      m2._ivalue()->compilation_unit(),
      "__torch__.FooTest",
      std::make_shared<Source>(classSrcs2),
      constantTable);

  c10::QualifiedName base("__torch__");
  const c10::QualifiedName fooTestName(base, "FooTest");

  // Incorrect arguments for the constructor (a single int) should throw.
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
  ASSERT_ANY_THROW(m1.create_class(fooTestName, {1}));

  // m2's FooTest keeps its constructor argument in `dx`.
  auto input = torch::ones({2, 3});
  auto fooObj = m2.create_class(fooTestName, input).toObject();
  ASSERT_TRUE(almostEqual(input, fooObj->getAttr("dx").toTensor()));

  // A value written with setAttr must be what the next getAttr returns.
  auto replacement = torch::rand({2, 3});
  fooObj->setAttr("dx", replacement);
  ASSERT_TRUE(almostEqual(replacement, fooObj->getAttr("dx").toTensor()));
}
112 
// TorchScript source defining a single method, `__init__(self, x)`, whose body
// returns its argument.
static constexpr const char* methodSrc = R"JIT(
def __init__(self, x):
    return x
)JIT";
117 
// A ClassType derived from another via `refine` or `withContained` should keep
// the original class's attributes and methods.
TEST(ClassImportTest, ClassDerive) {
  auto cu = std::make_shared<CompilationUnit>();
  auto cls = ClassType::create("foo.bar", cu);
  const auto self = SimpleSelf(cls);
  auto methods = cu->define("foo.bar", methodSrc, nativeResolver(), &self);
  // Stop the test cleanly rather than indexing into an empty result if the
  // definition produced no functions.
  ASSERT_FALSE(methods.empty());
  auto method = methods[0];
  cls->addAttribute("attr", TensorType::get());
  ASSERT_TRUE(cls->findMethod(method->name()));

  // Refining a new class should retain attributes and methods.
  auto newCls = cls->refine({TensorType::get()});
  ASSERT_TRUE(newCls->hasAttribute("attr"));
  ASSERT_TRUE(newCls->findMethod(method->name()));

  // So should rebuilding the class with new contained types.
  auto newCls2 = cls->withContained({TensorType::get()})->expect<ClassType>();
  ASSERT_TRUE(newCls2->hasAttribute("attr"));
  ASSERT_TRUE(newCls2->findMethod(method->name()));
}
136 
137 static constexpr c10::string_view torchbindSrc = R"JIT(
138 class FooBar1234(Module):
139   __parameters__ = []
140   f : __torch__.torch.classes._TorchScriptTesting._StackString
141   training : bool
142   def forward(self: __torch__.FooBar1234) -> str:
143     return (self.f).top()
144 )JIT";
145 
// Importing a module class that has an attribute of a TorchBind custom class
// type should succeed and register the class in the compilation unit.
TEST(ClassImportTest, CustomClass) {
  auto cu1 = std::make_shared<CompilationUnit>();
  std::vector<at::IValue> constantTable;
  // Import FooBar1234, whose attribute `f` refers to the custom class
  // `_TorchScriptTesting._StackString`.
  import_libs(
      cu1,
      "__torch__.FooBar1234",
      std::make_shared<Source>(torchbindSrc),
      constantTable);

  // The import must actually produce the class, including its custom-class
  // attribute; previously this test asserted nothing.
  c10::QualifiedName base("__torch__");
  auto cls = cu1->get_class(c10::QualifiedName(base, "FooBar1234"));
  ASSERT_TRUE(cls);
  ASSERT_TRUE(cls->hasAttribute("f"));
}
156 
157 } // namespace jit
158 } // namespace torch
159