#include #include #include #include #include #include #include namespace torch { namespace jit { static const std::vector subMethodSrcs = {R"JIT( def one(self, x: Tensor, y: Tensor) -> Tensor: return x + y + 1 def forward(self, x: Tensor) -> Tensor: return x )JIT"}; static const std::string parentForward = R"JIT( def forward(self, x: Tensor) -> Tensor: return self.subMod.forward(x) )JIT"; static constexpr c10::string_view moduleInterfaceSrc = R"JIT( class OneForward(ModuleInterface): def one(self, x: Tensor, y: Tensor) -> Tensor: pass def forward(self, x: Tensor) -> Tensor: pass )JIT"; static void import_libs( std::shared_ptr cu, const std::string& class_name, const std::shared_ptr& src, const std::vector& tensor_table) { SourceImporter si( cu, &tensor_table, [&](const std::string& name) -> std::shared_ptr { return src; }, /*version=*/2); si.loadType(QualifiedName(class_name)); } TEST(InterfaceTest, ModuleInterfaceSerialization) { auto cu = std::make_shared(); Module parentMod("parentMod", cu); Module subMod("subMod", cu); std::vector constantTable; import_libs( cu, "__torch__.OneForward", std::make_shared(moduleInterfaceSrc), constantTable); for (const std::string& method : subMethodSrcs) { subMod.define(method, nativeResolver()); } parentMod.register_attribute( "subMod", cu->get_interface("__torch__.OneForward"), subMod._ivalue(), // NOLINTNEXTLINE(bugprone-argument-comment) /*is_parameter=*/false); parentMod.define(parentForward, nativeResolver()); ASSERT_TRUE(parentMod.hasattr("subMod")); std::stringstream ss; parentMod.save(ss); Module reloaded_mod = jit::load(ss); ASSERT_TRUE(reloaded_mod.hasattr("subMod")); InterfaceTypePtr submodType = reloaded_mod.type()->getAttribute("subMod")->cast(); ASSERT_TRUE(submodType->is_module()); } } // namespace jit } // namespace torch