xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_interface.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <test/cpp/jit/test_utils.h>
4 
5 #include <ATen/core/qualified_name.h>
6 #include <torch/csrc/jit/frontend/resolver.h>
7 #include <torch/csrc/jit/serialization/import.h>
8 #include <torch/csrc/jit/serialization/import_source.h>
9 #include <torch/torch.h>
10 
11 namespace torch {
12 namespace jit {
13 
14 static const std::vector<std::string> subMethodSrcs = {R"JIT(
15 def one(self, x: Tensor, y: Tensor) -> Tensor:
16     return x + y + 1
17 
18 def forward(self, x: Tensor) -> Tensor:
19     return x
20 )JIT"};
21 static const std::string parentForward = R"JIT(
22 def forward(self, x: Tensor) -> Tensor:
23     return self.subMod.forward(x)
24 )JIT";
25 
26 static constexpr c10::string_view moduleInterfaceSrc = R"JIT(
27 class OneForward(ModuleInterface):
28     def one(self, x: Tensor, y: Tensor) -> Tensor:
29         pass
30     def forward(self, x: Tensor) -> Tensor:
31         pass
32 )JIT";
33 
import_libs(std::shared_ptr<CompilationUnit> cu,const std::string & class_name,const std::shared_ptr<Source> & src,const std::vector<at::IValue> & tensor_table)34 static void import_libs(
35     std::shared_ptr<CompilationUnit> cu,
36     const std::string& class_name,
37     const std::shared_ptr<Source>& src,
38     const std::vector<at::IValue>& tensor_table) {
39   SourceImporter si(
40       cu,
41       &tensor_table,
42       [&](const std::string& name) -> std::shared_ptr<Source> { return src; },
43       /*version=*/2);
44   si.loadType(QualifiedName(class_name));
45 }
46 
// Saves and reloads a parent module whose `subMod` attribute is declared
// with the `OneForward` module-interface type. After reloading, it checks
// that the attribute still exists and that its type is an InterfaceType
// marked as a module interface.
TEST(InterfaceTest, ModuleInterfaceSerialization) {
  auto cu = std::make_shared<CompilationUnit>();
  Module parentMod("parentMod", cu);
  Module subMod("subMod", cu);

  // Load the `OneForward` interface into `cu` so it can be looked up below
  // with cu->get_interface.
  std::vector<at::IValue> constantTable;
  import_libs(
      cu,
      "__torch__.OneForward",
      std::make_shared<Source>(moduleInterfaceSrc),
      constantTable);

  // Define `one` and `forward` on the submodule.
  for (const std::string& method : subMethodSrcs) {
    subMod.define(method, nativeResolver());
  }

  // Attach the submodule to the parent, typed by the interface rather than
  // by the submodule's own class type.
  parentMod.register_attribute(
      "subMod",
      cu->get_interface("__torch__.OneForward"),
      subMod._ivalue(),
      // NOLINTNEXTLINE(bugprone-argument-comment)
      /*is_parameter=*/false);
  parentMod.define(parentForward, nativeResolver());
  ASSERT_TRUE(parentMod.hasattr("subMod"));

  // Serialize to an in-memory stream and load it back.
  std::stringstream ss;
  parentMod.save(ss);
  Module reloaded_mod = jit::load(ss);
  ASSERT_TRUE(reloaded_mod.hasattr("subMod"));

  // cast<InterfaceType>() produces a null pointer when the attribute type is
  // not an InterfaceType. Assert on that explicitly so a regression fails
  // this test instead of crashing the binary on a null dereference.
  InterfaceTypePtr submodType =
      reloaded_mod.type()->getAttribute("subMod")->cast<InterfaceType>();
  ASSERT_NE(submodType, nullptr);
  ASSERT_TRUE(submodType->is_module());
}
78 
79 } // namespace jit
80 } // namespace torch
81