xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_backend.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <test/cpp/jit/test_utils.h>
3 #include <torch/csrc/jit/api/module.h>
4 #include <torch/csrc/jit/backends/backend_detail.h>
5 #include <torch/csrc/jit/mobile/import.h>
6 #include <torch/csrc/jit/serialization/import.h>
7 #include <torch/torch.h>
8 
9 // Tests go in torch::jit
10 namespace torch {
11 namespace jit {
TEST(BackendTest, ToBackend) {
  // Lowers a module whose forward() fans out to two helper methods onto
  // "test_backend", then checks that the lowered forward() reproduces both
  // elements of the original (sum, difference) tuple.
  Module original("m");
  original.define(R"(
    def forward(self, x, h):
        return self.accum(x, h), self.sub_accum(x, h)

    def accum(self, x, h):
        return x + h

    def sub_accum(self, x, h):
        return x - h
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(2.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  const auto expected = original.forward(inputs).toTupleRef().elements().vec();

  // Compile spec: {"forward": {"": ""}} -- a placeholder per-method entry.
  c10::Dict<IValue, IValue> forward_spec(StringType::get(), AnyType::get());
  forward_spec.insert("", "");
  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  compile_spec.insert("forward", forward_spec);
  const auto spec_type = DictType::create(StringType::get(), AnyType::get());

  // When this test was written, the generated wrapper looked roughly like
  // this (reformatted for readability):
  //
  //   class test_backendLoweredModule(Module):
  //     __parameters__ = []
  //     __buffers__ = []
  //     __processed_module : Any
  //     __method_compile_spec : Dict[str, Any]
  //     __backend : __torch__.torch.classes.__backends__.test_backend
  //     __handles : Dict[str, Any]
  //     def __create_backend(self) -> None:
  //       _0 = __torch__.torch.classes.__backends__.test_backend.__new__(...)
  //       _1 = (_0).__init__()
  //       self.__backend = _0
  //     def __getstate__(self) -> Tuple[Dict[str, Any], Any]:
  //       return (self.__method_compile_spec, self.__processed_module)
  //     def __setstate__(self, state: Tuple[Dict[str, Any], Any]) -> None:
  //       self.__method_compile_spec = (state)[0]
  //       self.__processed_module = (state)[1]
  //       _3 = (self).__create_backend()
  //       self.__handles = (self.__backend).compile(
  //           self.__processed_module, self.__method_compile_spec, )
  //     def forward(self, x: Tensor, h: Tensor) -> Tuple[Tensor, Tensor]:
  //       typed_inputs = annotate(List[Any], [x, h])
  //       _7, _8, = (self.__backend).execute(
  //           (self.__handles)["forward"], typed_inputs, )
  //       # each output is checked with isinstance(_, Tensor); a mismatch
  //       # raises "AssertionError: "
  //       return (unchecked_cast(Tensor, _7), unchecked_cast(Tensor, _8))
  auto lowered = torch::jit::detail::codegen_backend_module(
      "test_backend", original, compile_spec, spec_type);

  const auto actual = lowered.forward(inputs).toTupleRef().elements().vec();
  for (size_t idx = 0; idx < 2; ++idx) {
    AT_ASSERT(actual[idx].toTensor().equal(expected[idx].toTensor()));
  }
}
82 
TEST(BackendTest, ToBackendNotAvailable) {
  // Lowering onto an unavailable backend succeeds; the failure only surfaces
  // when the lowered module is executed.
  Module source_module("m");
  source_module.define(R"(
    def forward(self, x, h):
        return self.accum(x, h), self.sub_accum(x, h)

    def accum(self, x, h):
        return x + h

    def sub_accum(self, x, h):
        return x - h
  )");

  std::vector<IValue> args;
  args.emplace_back(2.0 * torch::ones({}));
  args.emplace_back(1.0 * torch::ones({}));
  // The un-lowered module must still run and produce a tuple; the value
  // itself is not compared against anything.
  [[maybe_unused]] const auto reference =
      source_module.forward(args).toTupleRef().elements().vec();

  c10::Dict<IValue, IValue> forward_spec(StringType::get(), AnyType::get());
  forward_spec.insert("", "");
  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  compile_spec.insert("forward", forward_spec);
  const auto spec_type = DictType::create(StringType::get(), AnyType::get());

  // No exception is expected from code generation itself.
  auto lowered = torch::jit::detail::codegen_backend_module(
      "test_backend_unavailable", source_module, compile_spec, spec_type);

  // Running it must report that the backend is missing.
  ASSERT_THROWS_WITH_MESSAGE(
      lowered.forward(args).toTupleRef().elements(),
      "Backend is not available.");
}
115 
TEST(BackendTest, TestCompiler) {
  // Lowers a single aten::add module to the demo compiler backend. The result
  // must match the original both directly and after a mobile save/load round
  // trip.
  Module adder("m");
  adder.define(R"(
    def forward(self, x, h):
        return x + h
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(2.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  const IValue expected = adder.forward(inputs);

  c10::Dict<IValue, IValue> forward_spec(StringType::get(), AnyType::get());
  forward_spec.insert("", "");
  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  compile_spec.insert("forward", forward_spec);
  const auto spec_type = DictType::create(StringType::get(), AnyType::get());

  auto lowered = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", adder, compile_spec, spec_type);
  const IValue lowered_out = lowered.forward(inputs);
  AT_ASSERT(lowered_out.toTensor().equal(expected.toTensor()));

  // Mobile (lite interpreter) round trip.
  std::stringstream buffer;
  lowered._save_for_mobile(buffer);
  auto mobile_module = _load_for_mobile(buffer);
  const IValue mobile_out = mobile_module.forward(inputs);
  AT_ASSERT(mobile_out.toTensor().equal(expected.toTensor()));
}
145 
TEST(BackendTest, TestCompilerWithStringTable) {
  // RAII toggle for the process-wide "use format with string table"
  // serialization flag. Before this guard, the flag was turned off only near
  // the end of the test. Any earlier failure (a throwing AT_ASSERT, or an
  // exception from lowering or serialization) therefore left it enabled for
  // every test that ran afterwards in the same process.
  struct StringTableFormatScope {
    StringTableFormatScope() {
      setShouldUseFormatWithStringTable(true);
    }
    ~StringTableFormatScope() {
      release();
    }
    // Turns the flag off immediately instead of at scope exit. Idempotent.
    void release() {
      if (engaged) {
        setShouldUseFormatWithStringTable(false);
        engaged = false;
      }
    }
    bool engaged = true;
  };
  StringTableFormatScope string_table_format;

  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return x + h
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(2.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  auto ref = m.forward(inputs);

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lm = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", m, compile_spec, any_dict_ty);
  auto res = lm.forward(inputs);
  AT_ASSERT(res.toTensor().equal(ref.toTensor()));

  // Save/load with the string-table format active.
  std::stringstream ss;
  lm._save_for_mobile(ss);
  auto mlm = _load_for_mobile(ss);
  auto mres = mlm.forward(inputs);
  // Serialization is done, so clear the flag now, as the original test did.
  // The destructor covers every early-exit path above.
  string_table_format.release();
  AT_ASSERT(mres.toTensor().equal(ref.toTensor()));
}
177 
TEST(BackendTest, TestComposite) {
  // A plain module composed of two separately lowered submodules must give
  // the same result under TorchScript and after a mobile save/load.
  c10::Dict<IValue, IValue> forward_spec(StringType::get(), AnyType::get());
  forward_spec.insert("", "");
  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  compile_spec.insert("forward", forward_spec);
  const auto spec_type = DictType::create(StringType::get(), AnyType::get());

  // Lowers `source` to the demo compiler backend with the shared spec.
  auto lower = [&](const Module& source) {
    return torch::jit::detail::codegen_backend_module(
        "backend_with_compiler_demo", source, compile_spec, spec_type);
  };

  Module adder("m_add");
  adder.define(R"(
    def forward(self, x, y):
      return x + y
  )");
  auto lowered_adder = lower(adder);

  Module subtractor("m_sub");
  subtractor.define(R"(
    def forward(self, x, y):
      return x - y
  )");
  auto lowered_subtractor = lower(subtractor);

  Module composite("C");
  composite.register_module("Add", lowered_adder);
  composite.register_module("Sub", lowered_subtractor);
  composite.define(R"(
    def forward(self, x, y):
      return self.Add.forward(x, y) * self.Sub.forward(x, y)
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(3.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  const IValue jit_out = composite.forward(inputs);

  std::stringstream buffer;
  composite._save_for_mobile(buffer);
  auto mobile_composite = _load_for_mobile(buffer);
  const IValue mobile_out = mobile_composite.forward(inputs);

  AT_ASSERT(jit_out.toTensor().equal(mobile_out.toTensor()));
}
221 
TEST(BackendTest, TestPrimDtype) {
  // prim::dtype (`y.dtype`) must yield the same integer-encoded dtype in
  // TorchScript and in the mobile interpreter.
  Module dtype_module("name");
  dtype_module.define(R"(
    def forward(self, x, y):
      c = y.dtype
      return c
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(3.0 * torch::ones({}));
  inputs.emplace_back(1.0 * torch::ones({}));
  const IValue jit_out = dtype_module.forward(inputs);

  std::stringstream buffer;
  dtype_module._save_for_mobile(buffer);
  auto mobile_module = _load_for_mobile(buffer);
  const IValue mobile_out = mobile_module.forward(inputs);

  ASSERT_EQ(jit_out.toInt(), mobile_out.toInt());
}
242 
getCompositeModuleWithSameNameSubModules()243 Module getCompositeModuleWithSameNameSubModules() {
244   // Two submodules with same module name but different forward and other
245   // functions should be serialized and loaded correctly.
246 
247   c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
248   c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
249   fake_dict.insert("", "");
250   compile_spec.insert("forward", fake_dict);
251   auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
252 
253   Module sub1("m_add");
254   sub1.define(R"(
255     def forward(self, x, y):
256       return x + y
257   )");
258   auto lowered_sub1 = torch::jit::detail::codegen_backend_module(
259       "backend_with_compiler_demo", sub1, compile_spec, any_dict_ty);
260 
261   Module sub2("m_add");
262   sub2.define(R"(
263     def forward(self, x, y):
264       return x - y
265   )");
266   auto lowered_sub2 = torch::jit::detail::codegen_backend_module(
267       "backend_with_compiler_demo", sub2, compile_spec, any_dict_ty);
268 
269   Module c("C");
270   c.register_module("Add", lowered_sub1);
271   c.register_module("Sub", lowered_sub2);
272   c.define(R"(
273     def forward(self, a, b, s:int):
274       c = self.Add.forward(a, b)
275       d = self.Sub.forward(a, b)
276       y = s * (c * d)
277       return y
278   )");
279 
280   return c;
281 }
282 
TEST(BackendTest, TestCompositeWithSetStates) {
  // The same-name-submodule composite must compute identical results in
  // TorchScript and after a mobile save/load.
  Module composite = getCompositeModuleWithSameNameSubModules();

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::ones({}));
  inputs.emplace_back(3.0 * torch::ones({}));
  inputs.emplace_back(3);
  const IValue jit_out = composite.forward(inputs);

  std::stringstream buffer;
  composite._save_for_mobile(buffer);
  auto mobile_composite = _load_for_mobile(buffer);
  const IValue mobile_out = mobile_composite.forward(inputs);

  AT_ASSERT(jit_out.toTensor().equal(mobile_out.toTensor()));
}
298 
TEST(BackendTest, TestConsistencyOfCompositeWithSetStates) {
  // Saves the composite as mobile, reloads those bytes as a full TorchScript
  // module, and re-saves it as mobile. Both mobile modules must produce the
  // same output and expose the same method qualified names, in the same
  // order.
  Module c = getCompositeModuleWithSameNameSubModules();

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::ones({}));
  inputs.emplace_back(3.0 * torch::ones({}));
  inputs.emplace_back(3);

  std::stringstream ss;
  std::stringstream ss_resave;
  c._save_for_mobile(ss);
  auto mc = _load_for_mobile(ss);
  auto res_mobile = mc.forward(inputs);
  // Rewind so the same bytes can be read again by torch::jit::load.
  ss.seekg(0, std::ios::beg);

  auto script_module_load = torch::jit::load(ss);
  script_module_load._save_for_mobile(ss_resave);
  auto mc_reload = _load_for_mobile(ss_resave);
  auto res_mobile_reload = mc_reload.forward(inputs);

  AT_ASSERT(res_mobile_reload.toTensor().equal(res_mobile.toTensor()));

  auto mc_methods = mc.get_methods();
  auto mc_reload_methods = mc_reload.get_methods();

  std::vector<std::string> mc_method_qns;
  std::vector<std::string> mc_reload_method_qns;

  auto get_qual_name = [](mobile::Method method) -> std::string {
    return method.function().qualname().qualifiedName();
  };

  std::transform(
      mc_methods.begin(),
      mc_methods.end(),
      std::back_inserter(mc_method_qns),
      get_qual_name);

  std::transform(
      mc_reload_methods.begin(),
      mc_reload_methods.end(),
      std::back_inserter(mc_reload_method_qns),
      get_qual_name);

  // Compare whole vectors. The previous three-iterator std::equal never
  // checked the lengths: a shorter reload list was read out of bounds, and
  // extra reloaded methods were silently ignored.
  ASSERT_EQ(mc_method_qns, mc_reload_method_qns);
}
351 
TEST(BackendTest, TestCompilerNotSupport) {
  // The demo compiler rejects aten::mul, so lowering itself (not execution)
  // must throw with a descriptive message.
  Module multiplier("m");
  multiplier.define(R"(
    def forward(self, x, h):
        return x * h
  )");

  c10::Dict<IValue, IValue> forward_spec(StringType::get(), AnyType::get());
  forward_spec.insert("", "");
  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  compile_spec.insert("forward", forward_spec);
  const auto spec_type = DictType::create(StringType::get(), AnyType::get());

  ASSERT_THROWS_WITH_MESSAGE(
      torch::jit::detail::codegen_backend_module(
          "backend_with_compiler_demo", multiplier, compile_spec, spec_type),
      "The node of aten::mul is not supported in this compiler. Source code:");
}
370 
TEST(BackendTestDebugInfo, TestCompiler) {
  // A failure inside the delegated backend must surface, after a mobile
  // save/load with debug info, as a TorchScript traceback that includes the
  // module hierarchy and the original source line.
  Module adder("m");
  adder.define(R"(
    def forward(self, x, h):
        return x + h
  )");

  // Shapes {2, 4} and {13, 9} cannot be broadcast, so aten::add fails.
  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  c10::Dict<IValue, IValue> forward_spec(StringType::get(), AnyType::get());
  forward_spec.insert("", "");
  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  compile_spec.insert("forward", forward_spec);
  const auto spec_type = DictType::create(StringType::get(), AnyType::get());

  auto lowered = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", adder, compile_spec, spec_type);

  std::stringstream buffer;
  lowered._save_for_mobile(buffer, ExtraFilesMap(), /*save_mobile_debug_info=*/true);
  auto mobile_module = _load_for_mobile(buffer);

  const std::string error_pattern = R"(
  Module hierarchy:top(m)::<unknown>.__loweredModule__(m)::forward.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

            def forward(self, x: Tensor, h: Tensor):
                return self.__loweredModule__.forward(x, h)
                       ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 5, in forward
                typed_inputs: List[Any] = [x, h, ]
                if self.__backend.is_available() :
                  _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
                        ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                  assert isinstance(_0, Tensor)
                  return _0
  File "<string>", line 3, in <unknown>

    def forward(self, x, h):
        return x + h
               ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(mobile_module.forward(inputs), error_pattern);
}
418 
TEST(BackendTestDebugInfo, TestCompilerWithStringTable) {
  // RAII toggle for the process-wide "use format with string table"
  // serialization flag. Before this guard, the flag was turned off only just
  // before the final assertion. An exception from lowering or serialization
  // therefore left it enabled for every later test in the same process.
  struct StringTableFormatScope {
    StringTableFormatScope() {
      setShouldUseFormatWithStringTable(true);
    }
    ~StringTableFormatScope() {
      release();
    }
    // Turns the flag off immediately instead of at scope exit. Idempotent.
    void release() {
      if (engaged) {
        setShouldUseFormatWithStringTable(false);
        engaged = false;
      }
    }
    bool engaged = true;
  };
  StringTableFormatScope string_table_format;

  Module m("m");
  m.define(R"(
    def forward(self, x, h):
        return x + h
  )");

  // Shapes {2, 4} and {13, 9} cannot be broadcast, so aten::add fails.
  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lm = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", m, compile_spec, any_dict_ty);

  std::stringstream ss;
  lm._save_for_mobile(ss, ExtraFilesMap(), true);
  auto mlm = _load_for_mobile(ss);
  std::string error_pattern = R"(
  Module hierarchy:top(m)::<unknown>.__loweredModule__(m)::forward.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

            def forward(self, x: Tensor, h: Tensor):
                return self.__loweredModule__.forward(x, h)
                       ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 5, in forward
                typed_inputs: List[Any] = [x, h, ]
                if self.__backend.is_available() :
                  _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
                        ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                  assert isinstance(_0, Tensor)
                  return _0
  File "<string>", line 3, in <unknown>

    def forward(self, x, h):
        return x + h
               ~~~~~ <--- HERE
  )";
  // Serialization is done, so clear the flag now, as the original test did.
  // The destructor covers every early-exit path above.
  string_table_format.release();
  ASSERT_THROWS_WITH_MESSAGE(mlm.forward(inputs), error_pattern);
}
468 
TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithModuleHierarchy) {
  // Lowers C (which calls submodules A0 and B0) as a whole. The traceback for
  // a failure in A0's aten::add must name the full module path
  // C -> __loweredModule__ -> A0.
  Module leaf_a("A");
  leaf_a.define(R"(
    def forward(self, x, y):
      return x + y
  )");
  Module leaf_b("B");
  leaf_b.define(R"(
    def forward(self, x):
      return x + 2
  )");
  Module parent("C");
  parent.register_module("A0", leaf_a);
  parent.register_module("B0", leaf_b);
  parent.define(R"(
    def forward(self, x, y):
      return self.A0.forward(x, y) + self.B0.forward(x)
  )");

  // Shapes {2, 4} and {13, 9} cannot be broadcast, so A0's aten::add fails.
  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  c10::Dict<IValue, IValue> forward_spec(StringType::get(), AnyType::get());
  forward_spec.insert("", "");
  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  compile_spec.insert("forward", forward_spec);
  const auto spec_type = DictType::create(StringType::get(), AnyType::get());

  auto lowered = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", parent, compile_spec, spec_type);

  std::stringstream buffer;
  lowered._save_for_mobile(buffer, ExtraFilesMap(), /*save_mobile_debug_info=*/true);
  auto mobile_module = _load_for_mobile(buffer);

  const std::string error_pattern = R"(
  Module hierarchy:top(C)::<unknown>.__loweredModule__(C)::forward.A0(A)::forward.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

            def forward(self, x: Tensor, y: Tensor):
                return self.__loweredModule__.forward(x, y)
                       ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 5, in forward
                typed_inputs: List[Any] = [x, y, ]
                if self.__backend.is_available() :
                  _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
                        ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                  assert isinstance(_0, Tensor)
                  return _0
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
      return self.A0.forward(x, y) + self.B0.forward(x)
             ~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in forward

    def forward(self, x, y):
      return x + y
             ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(mobile_module.forward(inputs), error_pattern);
}
534 
TEST(
    BackendTestDebugInfo,
    TestExceptionStackForCompilerWithTwoLevelModuleHierarchy) {
  // Lowers C -> B0 -> A0 (two levels of nesting) as a whole. The traceback
  // for a failure in the innermost aten::add must list every frame and the
  // full module path C -> __loweredModule__ -> B0 -> A0.
  Module inner("A");
  inner.define(R"(
    def forward(self, x, y):
      return x + y
  )");
  Module middle("B");
  middle.register_module("A0", inner);
  middle.define(R"(
    def forward(self, x, y):
      return self.A0.forward(x, y) + 2
  )");
  Module outer("C");
  outer.register_module("B0", middle);
  outer.define(R"(
    def forward(self, x, y):
      return self.B0.forward(x, y) + 3
  )");

  // Shapes {2, 4} and {13, 9} cannot be broadcast, so A0's aten::add fails.
  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  c10::Dict<IValue, IValue> forward_spec(StringType::get(), AnyType::get());
  forward_spec.insert("", "");
  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  compile_spec.insert("forward", forward_spec);
  const auto spec_type = DictType::create(StringType::get(), AnyType::get());

  auto lowered = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", outer, compile_spec, spec_type);

  std::stringstream buffer;
  lowered._save_for_mobile(buffer, ExtraFilesMap(), /*save_mobile_debug_info=*/true);
  auto mobile_module = _load_for_mobile(buffer);

  // Expected substring of the thrown message: the hierarchy line followed by
  // one traceback frame per call level, innermost last.
  const std::string error_pattern = R"(
  Module hierarchy:top(C)::<unknown>.__loweredModule__(C)::forward.B0(B)::forward.A0(A)::forward.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

            def forward(self, x: Tensor, y: Tensor):
                return self.__loweredModule__.forward(x, y)
                       ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 5, in forward
                typed_inputs: List[Any] = [x, y, ]
                if self.__backend.is_available() :
                  _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
                        ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                  assert isinstance(_0, Tensor)
                  return _0
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
      return self.B0.forward(x, y) + 3
             ~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in forward

    def forward(self, x, y):
      return self.A0.forward(x, y) + 2
             ~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in forward

    def forward(self, x, y):
      return x + y
             ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(mobile_module.forward(inputs), error_pattern);
}
639 
TEST(BackendTestDebugInfo, TestExceptionStackForCompilerWithLoweredSubModule) {
  // Only submodule A0 of C is lowered (C itself stays TorchScript). The
  // traceback for a failure in A0's aten::add must show C's frame first, then
  // the lowered wrapper.
  // (Removed an unused `std::shared_ptr<CompilationUnit> cu` local.)
  Module a("A");
  a.define(R"(
    def forward(self, x, y):
      return x + y
  )");
  Module b("B");
  b.define(R"(
    def forward(self, x):
      return x + 2
  )");
  Module c("C");
  c.register_module("A0", a);
  c.register_module("B0", b);
  c.define(R"(
    def forward(self, x, y):
      return self.A0.forward(x, y) + self.B0.forward(x)
  )");

  // Shapes {2, 4} and {13, 9} cannot be broadcast, so A0's aten::add fails.
  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  IValue submodule = c.attr("A0");
  Module current_sm = submodule.toModule();
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lowered_submodule = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", current_sm, compile_spec, any_dict_ty);

  // Swap A0 for its lowered counterpart: retype the attribute, replace the
  // value, then remap A's type to the lowered type in every method graph and
  // schema of C so they reference the new attribute type.
  c.type()->unsafeChangeAttributeType("A0", lowered_submodule.type());
  c.setattr("A0", lowered_submodule._ivalue());
  std::unordered_map<TypePtr, TypePtr> type_remap;
  type_remap[a.type()] = lowered_submodule.type();
  auto type_remap_fn = [&type_remap](TypePtr in) {
    auto it = type_remap.find(in);
    if (it == type_remap.end())
      return in;
    return it->second;
  };
  for (auto& fn : c.type()->methods()) {
    auto method = c.get_method(fn->name());
    auto graph = method.graph();
    graph->remapTypes(type_remap_fn);
    auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn);
    fn->setSchema(new_schema);
  }

  std::stringstream ss;
  c._save_for_mobile(ss, ExtraFilesMap(), true);
  auto c_loaded = _load_for_mobile(ss);
  std::string error_pattern = R"(
  Module hierarchy:top(C)::<unknown>.A0(A)::forward.__loweredModule__(A)::forward.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
      return self.A0.forward(x, y) + self.B0.forward(x)
             ~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in forward

            def forward(self, x: Tensor, y: Tensor):
                return self.__loweredModule__.forward(x, y)
                       ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 5, in forward
                typed_inputs: List[Any] = [x, y, ]
                if self.__backend.is_available() :
                  _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
                        ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                  assert isinstance(_0, Tensor)
                  return _0
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
      return x + y
             ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern);
}
726 
TEST(
    BackendTestDebugInfo,
    TestExceptionStackForCompilerWithSelectiveLoweredSubModule) {
  // Only submodule A0 (which itself contains AA0) is lowered; C stays
  // TorchScript. The traceback for a failure in AA0's aten::add must show C's
  // frame, the lowered wrapper, then A's and AA's frames.
  // (Removed an unused `std::shared_ptr<CompilationUnit> cu` local.)
  Module aa("AA");
  aa.define(R"(
    def forward(self, x, y):
      return x + y
  )");
  Module a("A");
  a.register_module("AA0", aa);
  a.define(R"(
    def forward(self, x, y):
      return self.AA0.forward(x, y) + 3
  )");
  Module b("B");
  b.define(R"(
    def forward(self, x):
      return x + 2
  )");
  Module c("C");
  c.register_module("A0", a);
  c.register_module("B0", b);
  c.define(R"(
    def forward(self, x, y):
      return self.A0.forward(x, y) + self.B0.forward(x)
  )");

  // Shapes {2, 4} and {13, 9} cannot be broadcast, so AA0's aten::add fails.
  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  c10::Dict<IValue, IValue> compile_spec(StringType::get(), AnyType::get());
  c10::Dict<IValue, IValue> fake_dict(StringType::get(), AnyType::get());
  fake_dict.insert("", "");
  compile_spec.insert("forward", fake_dict);
  IValue submodule = c.attr("A0");
  Module current_sm = submodule.toModule();
  auto any_dict_ty = DictType::create(StringType::get(), AnyType::get());
  // lowered module
  auto lowered_submodule = torch::jit::detail::codegen_backend_module(
      "backend_with_compiler_demo", current_sm, compile_spec, any_dict_ty);

  // Swap A0 for its lowered counterpart: retype the attribute, replace the
  // value, then remap A's type to the lowered type in every method graph and
  // schema of C so they reference the new attribute type.
  c.type()->unsafeChangeAttributeType("A0", lowered_submodule.type());
  c.setattr("A0", lowered_submodule._ivalue());
  std::unordered_map<TypePtr, TypePtr> type_remap;
  type_remap[a.type()] = lowered_submodule.type();
  auto type_remap_fn = [&type_remap](TypePtr in) {
    auto it = type_remap.find(in);
    if (it == type_remap.end())
      return in;
    return it->second;
  };
  for (auto& fn : c.type()->methods()) {
    auto method = c.get_method(fn->name());
    auto graph = method.graph();
    graph->remapTypes(type_remap_fn);
    auto new_schema = fn->getSchema().cloneWithRemappedTypes(type_remap_fn);
    fn->setSchema(new_schema);
  }

  std::stringstream ss;
  c._save_for_mobile(ss, ExtraFilesMap(), true);
  auto c_loaded = _load_for_mobile(ss);
  // Expected substring of the thrown message: the module-hierarchy line, then
  // one traceback frame per call level (C, lowered wrapper, backend execute,
  // A, AA), innermost last.
  std::string error_pattern = R"(
  Module hierarchy:top(C)::<unknown>.A0(A)::forward.__loweredModule__(A)::forward.AA0(AA)::forward.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
      return self.A0.forward(x, y) + self.B0.forward(x)
             ~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in forward

            def forward(self, x: Tensor, y: Tensor):
                return self.__loweredModule__.forward(x, y)
                       ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 5, in forward
                typed_inputs: List[Any] = [x, y, ]
                if self.__backend.is_available() :
                  _0, = self.__backend.execute(self.__handles["forward"], typed_inputs)
                        ~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                  assert isinstance(_0, Tensor)
                  return _0
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
      return self.AA0.forward(x, y) + 3
             ~~~~~~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in forward

    def forward(self, x, y):
      return x + y
             ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(c_loaded.forward(inputs), error_pattern);
}
859 
860 } // namespace jit
861 } // namespace torch
862