xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_custom_class.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <test/cpp/jit/test_custom_class_registrations.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/custom_class.h>
#include <torch/script.h>

#include <iostream>
#include <sstream>
#include <string>
#include <vector>
11 
12 namespace torch {
13 namespace jit {
14 
// Exercises the two ways of wrapping a C++ custom-class instance in an IValue
// (make_custom_class and the c10::IValue(intrusive_ptr) constructor) and checks
// that the instance passes through a TorchScript method that pops from it and
// returns it.
TEST(CustomClassTest, TorchbindIValueAPI) {
  script::Module m("m");

  // test make_custom_class API
  auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
      std::vector<std::string>{"foo", "bar"});
  m.define(R"(
    def forward(self, s : __torch__.torch.classes._TorchScriptTesting._StackString):
      return s.pop(), s
  )");

  // Runs `forward` on `obj` and checks that the popped element equals
  // `expected` and that the returned object is the very same instance that was
  // passed in (pointer identity, not just equal contents). An ASSERT failure
  // here only ends this lambda call; gtest still records it as a test failure.
  auto test_with_obj = [&m](const IValue& obj, const std::string& expected) {
    auto res = m.run_method("forward", obj);
    auto tup = res.toTuple();
    // Check arity before indexing so a malformed result cannot be read out of
    // bounds.
    ASSERT_EQ(tup->elements().size(), 2u);
    // `tup` keeps the elements alive, so binding by reference avoids a copy.
    const std::string& str = tup->elements()[0].toStringRef();
    EXPECT_EQ(str, expected);
    auto other_obj =
        tup->elements()[1].toCustomClass<MyStackClass<std::string>>();
    auto ref_obj = obj.toCustomClass<MyStackClass<std::string>>();
    EXPECT_EQ(other_obj.get(), ref_obj.get());
  };

  test_with_obj(custom_class_obj, "bar");

  // test IValue() API
  auto my_new_stack = c10::make_intrusive<MyStackClass<std::string>>(
      std::vector<std::string>{"baz", "boo"});
  auto new_stack_ivalue = c10::IValue(my_new_stack);

  test_with_obj(new_stack_ivalue, "boo");
}
47 
// Checks that a module holding a ScalarTypeClass custom-class attribute can be
// saved and loaded back with the attribute intact.
TEST(CustomClassTest, ScalarTypeClass) {
  script::Module m("m");

  // test make_custom_class API
  auto cc = make_custom_class<ScalarTypeClass>(at::kFloat);
  m.register_attribute("s", cc.type(), cc, /*is_param=*/false);

  // Round-trip through an in-memory stream. torch::jit::load reads the istream
  // directly, so no separate IStreamAdapter is needed.
  std::ostringstream oss;
  m.save(oss);
  std::istringstream iss(oss.str());
  auto loaded_module = torch::jit::load(iss, torch::kCPU);

  // Loading without error is not the whole contract: the custom-class
  // attribute must also survive serialization.
  EXPECT_TRUE(loaded_module.hasattr("s"));
}
61 
62 class TorchBindTestClass : public torch::jit::CustomClassHolder {
63  public:
get()64   std::string get() {
65     return "Hello, I am your test custom class";
66   }
67 };
68 
69 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
70 constexpr char class_doc_string[] = R"(
71   I am docstring for TorchBindTestClass
72   Args:
73       What is an argument? Oh never mind, I don't take any.
74 
75   Return:
76       How would I know? I am just a holder of some meaningless test methods.
77   )";
78 // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
79 constexpr char method_doc_string[] =
80     "I am docstring for TorchBindTestClass get_with_docstring method";
81 
82 namespace {
83 static auto reg =
84     torch::class_<TorchBindTestClass>(
85         "_TorchBindTest",
86         "_TorchBindTestClass",
87         class_doc_string)
88         .def("get", &TorchBindTestClass::get)
89         .def("get_with_docstring", &TorchBindTestClass::get, method_doc_string);
90 
91 } // namespace
92 
93 // Tests DocString is properly propagated when defining CustomClasses.
TEST(CustomClassTest,TestDocString)94 TEST(CustomClassTest, TestDocString) {
95   auto class_type = getCustomClass(
96       "__torch__.torch.classes._TorchBindTest._TorchBindTestClass");
97   AT_ASSERT(class_type);
98   AT_ASSERT(class_type->doc_string() == class_doc_string);
99 
100   AT_ASSERT(class_type->getMethod("get").doc_string().empty());
101   AT_ASSERT(
102       class_type->getMethod("get_with_docstring").doc_string() ==
103       method_doc_string);
104 }
105 
// Checks that a module holding a custom-class attribute runs correctly both
// as-is and after freezing, and that both variants still run correctly after a
// save/load round trip.
TEST(CustomClassTest, Serialization) {
  script::Module m("m");

  // test make_custom_class API
  auto custom_class_obj = make_custom_class<MyStackClass<std::string>>(
      std::vector<std::string>{"foo", "bar"});
  m.register_attribute(
      "s", custom_class_obj.type(), custom_class_obj, /*is_param=*/false);
  m.define(R"(
    def forward(self):
      return self.s.return_a_tuple()
  )");

  // Runs `forward` and checks that it yields a 2-tuple whose second element is
  // 123. An ASSERT failure here only ends this lambda call; gtest still records
  // it as a test failure.
  auto test_with_obj = [](script::Module& mod) {
    auto res = mod.run_method("forward");
    auto tup = res.toTuple();
    // Check arity before indexing so a malformed result cannot be read out of
    // bounds.
    ASSERT_EQ(tup->elements().size(), 2u);
    EXPECT_EQ(tup->elements()[1].toInt(), 123);
  };

  // Serializes `mod` to an in-memory stream and loads it back onto the CPU.
  // torch::jit::load reads the istream directly, so no IStreamAdapter is
  // needed.
  auto round_trip = [](const script::Module& mod) {
    std::ostringstream oss;
    mod.save(oss);
    std::istringstream iss(oss.str());
    return torch::jit::load(iss, torch::kCPU);
  };

  auto frozen_m = torch::jit::freeze_module(m.clone());

  test_with_obj(m);
  test_with_obj(frozen_m);

  // Loading without error is not enough: the reloaded modules must still run.
  auto loaded_module = round_trip(m);
  test_with_obj(loaded_module);

  auto loaded_frozen_module = round_trip(frozen_m);
  test_with_obj(loaded_frozen_module);
}
148 
149 } // namespace jit
150 } // namespace torch
151