#include <test/cpp/jit/test_utils.h>

#include <gtest/gtest.h>

#include <c10/core/TensorOptions.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/mobile/compatibility/backport.h>
#include <torch/csrc/jit/mobile/compatibility/backport_manager.h>
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/custom_class.h>
#include <torch/torch.h>

#include <unordered_set>

// Tests go in torch::jit
namespace torch {
namespace jit {

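// These tests convert full JIT modules to mobile (lite interpreter) modules
// directly in memory via jitModuleToMobile() and check the mobile results
// against the full JIT results.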
TEST(LiteInterpreterDirectTest, UpsampleNearest2d) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  auto ref = m.forward(inputs);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  res = bc.forward(inputs);

  auto resd = res.toTensor();
  auto refd = ref.toTensor();
  ASSERT_TRUE(resd.equal(refd));
}

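// Attributes registered on the eager module should be readable from the
// mobile module, and re-converting after setattr() should pick up the
// updated value.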
TEST(LiteInterpreterDirectTest, CheckAttrAccess) {
  Module m("m");
  m.register_attribute("mobile_optimized", BoolType::get(), true);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bool mobile_optimized = bc.attr("mobile_optimized", false).toBool();

  AT_ASSERT(mobile_optimized);
  m.setattr("mobile_optimized", false);
  bc = jitModuleToMobile(m, options);
  mobile_optimized = bc.attr("mobile_optimized", false).toBool();
  AT_ASSERT(!mobile_optimized);
}

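// Invokes methods with and without default arguments; the second program
// exercises an inner method call, which gets inlined during conversion.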
TEST(
    LiteInterpreterDirectTest,
    MethodInvocation) { // NOLINT (use =delete in gtest)
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
      def test_func(self, x, b : int = 4):
        return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
      def add_with_default_arg(self, x, b : int = 4):
        return self.foo + x + b
      def test_func(self, x):
        return self.add_with_default_arg(x)  # invoke method w/ default arg
      )",
      // simple method call
      R"(
      def test_func(self, x):
        b = 4
        return self.foo + x + b
      )",
  };
  for (const auto& test_program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(test_program);

    const int fortyTwo = 42; // (keep linter happy)
    auto minput = fortyTwo * torch::ones({});
    auto ref = m.run_method("test_func", minput);

    CompilationOptions options;
    mobile::Module bc = jitModuleToMobile(m, options);
    const auto& test_func = bc.get_method("test_func");
    IValue res;
    for (int i = 0; i < 3; ++i) {
      res = test_func({minput});
    }

    auto resd = res.toTensor().item<float>();
    auto refd = ref.toTensor().item<float>();
    AT_ASSERT(resd == refd);
  }
}

TEST(LiteInterpreterDirectTest, Conv) {
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && strcmp(s, "1") == 0)
    return;

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}

TEST(LiteInterpreterDirectTest, Inline) {
  Module m("m");
  m.define(R"JIT(
  def foo1(self, x):
      return x + 1

  def foo2(self, x):
      return self.foo1(x) + 2

  def foo3(self, x):
      return self.foo2(x) + 3
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("foo3")(inputs);
  AT_ASSERT(output.toTensor().item<float>() == 7.0);
}

TEST(LiteInterpreterDirectTest, Tuple) {
  Module m("m");
  m.define(R"JIT(
  def foo(self, x):
      return (1, 2, x + 3)

  def forward(self, x):
      tuple = self.foo(x)
      return tuple
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("forward")(inputs);
  AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);
}

TEST(LiteInterpreterDirectTest, Dict) {
  Module m("m");
  m.define(R"JIT(
  def foo(self, x):
      return {"result": x + 1}

  def forward(self, x):
      d = self.foo(x)
      return d
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("forward")(inputs);
  AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
}

TEST(LiteInterpreterDirectTest, Prim) {
  Module m("m");
  m.define(R"JIT(
        def forward(self, x):
            return int(x)
  )JIT");

  std::vector<IValue> inputs;
  auto minput = 3.5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);

  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resi = res.toInt();
  auto refi = ref.toInt();
  AT_ASSERT(resi == refi);
}

TEST(LiteInterpreterDirectTest, PrimScalar) {
  Module m("m");
  m.define(R"JIT(
        def forward(self, x):
            return int(x.item())
  )JIT");

  std::vector<IValue> inputs;
  auto minput = 3.5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resi = res.toInt();
  auto refi = ref.toInt();
  AT_ASSERT(resi == refi);
}

TEST(LiteInterpreterDirectTest, WrongMethodName) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  ASSERT_THROWS_WITH_MESSAGE(
      bc.get_method("forward")(inputs), "is not defined");
}

TEST(LiteInterpreterDirectTest, SetState) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);

  std::stringstream ms;
  m.save(ms);
  auto loaded_m = load(ms);
  auto ref = loaded_m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

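// Custom class used by the BuiltinFunction test below; get() returns a
// string describing the tensor's element count.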
class TorchBindLiteInterpreterDirectTestStruct
    : public torch::jit::CustomClassHolder {
 public:
  std::string get(at::Tensor t) {
    std::stringstream ss;
    ss << "Hello! Your tensor has ";
    ss << t.numel();
    ss << " elements!";
    return ss.str();
  }
};

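// Resolver helpers: `torch` resolves to the aten builtin module, while names
// under `__torch__` are looked up as custom classes, falling back to nested
// namespaces when no custom class matches.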
namespace {
struct ClassNamespaceValue : public SugaredValue {
  explicit ClassNamespaceValue(c10::QualifiedName name)
      : basename_(std::move(name)) {}

  std::shared_ptr<SugaredValue> attr(
      const SourceRange&,
      GraphFunction&,
      const std::string& name) override {
    const auto fullName = c10::QualifiedName(basename_, name);

    // Check to see if it is a custom class.
    if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
      return std::make_shared<ClassValue>(custom_class);
    }

    // If it's not a custom class, assume it's another namespace
    // NOLINTNEXTLINE(performance-move-const-arg)
    return std::make_shared<ClassNamespaceValue>(fullName);
  }

  std::string kind() const override {
    return "Class Namespace";
  }

 private:
  c10::QualifiedName basename_;
};

struct TestModuleResolver : public Resolver {
  std::shared_ptr<SugaredValue> resolveValue(
      const std::string& name,
      GraphFunction&,
      const SourceRange&) override {
    if (name == "torch") {
      return std::make_shared<BuiltinModule>("aten");
    } else if (name == "__torch__") {
      return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
    }

    return nullptr;
  }

  TypePtr resolveType(const std::string&, const SourceRange&) override {
    return nullptr;
  }
};
} // namespace

TEST(LiteInterpreterDirectTest, BuiltinFunction) {
  script::Module m("m");
  auto custom_class_obj =
      make_custom_class<TorchBindLiteInterpreterDirectTestStruct>();
  m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj);
  m.define(R"(
    def forward(self, x) -> str:
      return self.my_obj.get(x)
  )");

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  auto res =
      bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  auto str = res.toStringRef();
  std::string expected = "Hello! Your tensor has 12 elements!";
  AT_ASSERT(str == expected);
}

#if !defined FB_XPLAT_BUILD
TEST(LiteInterpreterDirectTest, GetRuntimeByteCodeVersion) {
  auto runtime_bytecode_version = _get_runtime_bytecode_version();
  AT_ASSERT(
      runtime_bytecode_version ==
      caffe2::serialize::kMaxSupportedBytecodeVersion);
}

TEST(LiteInterpreterDirectTest, GetRuntimeOperatorsVersion) {
  auto runtime_operators_version = _get_runtime_operators_min_max_versions();
  AT_ASSERT(
      runtime_operators_version.first ==
          caffe2::serialize::kMinSupportedFileFormatVersion &&
      runtime_operators_version.second ==
          caffe2::serialize::kMaxSupportedFileFormatVersion);
}

/**
 * The test below is disabled for FB internal xplat builds since
 * BUCK requires us to pass the script_module_v4.ptl file in as a
 * resource dependency of the build rule for this file, and we
 * would need to access it via the C++ Resources API instead of
 * directly reading from disk (which is what the open source
 * build/run does).
 */
TEST(LiteInterpreterDirectTest, GetByteCodeVersion) {
  std::string filePath(__FILE__);
  auto test_model_file_v4 =
      filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file_v4.append("script_module_v4.ptl");

  auto version_v4 = _get_model_bytecode_version(test_model_file_v4);
  AT_ASSERT(version_v4 == 4);
}

#endif // !defined(FB_XPLAT_BUILD)

TEST(LiteInterpreterDirectTest, GetRuntimeOpsAndInfo) {
  auto runtime_ops = _get_runtime_ops_and_info();
  // Ballpark estimate of the minimal number of ops; just used to
  // verify API returns a reasonably large number.
  AT_ASSERT(runtime_ops.size() > 2900);
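  // A hedged sketch (not part of the original test): entries are keyed by
  // operator name; assuming the mapped info type exposes an optional
  // num_schema_args field, a single entry could be inspected like so:
  //   auto it = runtime_ops.find("aten::add.Tensor");
  //   if (it != runtime_ops.end() && it->second.num_schema_args) {
  //     ASSERT_GT(*it->second.num_schema_args, 0);
  //   }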
}

TEST(LiteInterpreterDirectTest, Eval) {
  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.define(R"(
    def __init__(self, x):
      self.training = True

    def forward(self, input):
      return torch.dropout(input, 1.0, self.training)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));
  m.eval();
  auto outputref = m.forward(inputs).toTensor();

  // Save m in training mode to make sure that mobile eval() correctly
  // switches it back to eval mode.
  m.train();
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bc.eval();
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}

TEST(LiteInterpreterDirectTest, FindWrongMethodName) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  ASSERT_TRUE(bc.find_method("forward") == std::nullopt);
}

TEST(LiteInterpreterDirectTest, FindAndRunMethod) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_it(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.get_method("add_it")(inputs);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    auto bcinputs = inputs;
    auto method = bc.find_method("add_it");
    AT_ASSERT(method != std::nullopt);
    res = (*method)(std::move(bcinputs));
  }

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

TEST(LiteInterpreterDirectTest, RunMethodVariadic) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_three(self, x, y):
      return self.foo + x + y
  )");

  std::vector<IValue> inputs;
  auto inputx = 5 * torch::ones({});
  auto inputy = 4 * torch::ones({});
  auto ref = m.run_method("add_three", inputx, inputy);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res = bc.run_method("add_three", inputx, inputy);

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

TEST(LiteInterpreterDirectTest, DuplicateSetState) {
  Module m("M");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo + self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  Module b("B");
  b.register_module("M0", m);
  b.register_module("M1", m);
  b.define(R"(
    def forward(self, x):
      return self.M0.forward(x) + self.M1.forward(x)
  )");

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  const auto methods = bc.get_methods();
  const size_t expected_n = 3;
  ASSERT_EQ(methods.size(), expected_n);
}

TEST(LiteInterpreterDirectTest, OpNameExportFetchRootOperators) {
  torch::jit::Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      x1 = torch.zeros(2, 2)
      x2 = torch.empty_like(torch.empty(2, 2))
      x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
      return (x1, x2, x3)
  )");
  m.eval();

  CompilationOptions options;
  mobile::Module ptl_model = jitModuleToMobile(m, options);
  std::set<std::string> operator_names =
      torch::jit::mobile::_export_operator_list(ptl_model);
  std::set<std::string> expected_operator_names = {
      "aten::_convolution",
      "aten::empty.memory_format",
      "aten::empty_like",
      "aten::zeros",
  };
  EXPECT_EQ(operator_names, expected_operator_names)
      << "Expected the root operator lists to be the same";
}

TEST(LiteInterpreterDirectTest, DefaultArgsConv) {
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && strcmp(s, "1") == 0)
    return;

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1)
  )");

  inputs.emplace_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 1; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(output.equal(outputref));
}

namespace {
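// Shared helper: runs `method_name` on both the full JIT module and its
// mobile conversion and asserts that the resulting tensors are equal.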
void testLiteModuleCompareResultTensors(
    Module& m,
    const std::vector<torch::jit::IValue>& inputs,
    const std::string& method_name = "forward") {
  auto outputref = m.get_method(method_name)(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method(method_name)(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(output.equal(outputref));
}

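// Builds a module that calls torch.linalg_pinv with `num_args` arguments
// specified explicitly; the remaining arguments fall back to their schema
// defaults.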
void testDefaultArgsPinv2(int num_args) {
  Module m("m");
  if (num_args == 1) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input)
    )");
  } else if (num_args == 2) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5)
    )");
  } else if (num_args == 3) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, True)
    )");
  }

  std::vector<torch::jit::IValue> inputs;
  const int N = 28;
  auto input = torch::range(1, N * N, 1);
  input[0] = 1; // a more stable matrix
  input = input.view({N, N});
  inputs.emplace_back(input);
  testLiteModuleCompareResultTensors(m, inputs);
}
} // namespace

#if !defined FB_XPLAT_BUILD
TEST(LiteInterpreterDirectTest, DefaultArgsPinv) {
  // Test with different numbers of specified arguments.
  // Arguments not specified take their default values.
  for (int num_args = 1; num_args <= 3; ++num_args) {
    testDefaultArgsPinv2(num_args);
  }

  //  bytecode with one specified argument:
  //  (6,
  //      ('__torch__.m.forward',
  //          (('instructions',
  //              (('STOREN', 1, 2),
  //                  ('DROPR', 1, 0),
  //                  ('MOVE', 2, 0),
  //                  ('OP', 0, 0),
  //                  ('RET', 0, 0))),
  //              ('operators', (('aten::linalg_pinv', '', 1),)),
  //              ('constants', (False, 1e-15)), # default constants are not
  //              used
  //              ('types', ()),
  //              ('register_size', 2)),
  //          (('arguments',
  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
  //              None)),
  //                  (('name', 'input'), ('type', 'Tensor'), ('default_value',
  //                  None)))),
  //              ('returns',
  //                  ((('name', ''), ('type', 'Tensor'), ('default_value',
  //                  None)),)))))

  //  bytecode with two specified arguments:
  //  (6,
  //      ('__torch__.m.forward',
  //          (('instructions',
  //              (('STOREN', 1, 2),
  //                  ('DROPR', 1, 0),
  //                  ('MOVE', 2, 0),
  //                  ('LOADC', 1, 0), # added LOADC for specified argument
  //                  ('OP', 0, 0),
  //                  ('RET', 0, 0))),
  //              ('operators', (('aten::linalg_pinv', '', 2),)),
  //              ('constants', (False, 1e-05)), # updated constant table
  //              ('types', ()),
  //              ('register_size', 2)),
  //          (('arguments',
  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
  //              None)),
  //                  (('name', 'input'), ('type', 'Tensor'), ('default_value',
  //                  None)))),
  //              ('returns',
  //                  ((('name', ''), ('type', 'Tensor'), ('default_value',
  //                  None)),)))))

  //  bytecode with three specified arguments:
  //  (6,
  //      ('__torch__.m.forward',
  //          (('instructions',
  //              (('STOREN', 1, 2),
  //                  ('DROPR', 1, 0),
  //                  ('MOVE', 2, 0),
  //                  ('LOADC', 1, 0),
  //                  ('LOADC', 0, 0),
  //                  ('OP', 0, 0),
  //                  ('RET', 0, 0))),
  //              ('operators', (('aten::linalg_pinv', '', 3),)),
  //              ('constants', (True, 1e-05)),
  //              ('types', ()),
  //              ('register_size', 2)),
  //          (('arguments',
  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
  //              None)),
  //                  (('name', 'input'), ('type', 'Tensor'), ('default_value',
  //                  None)))),
  //              ('returns',
  //                  ((('name', ''), ('type', 'Tensor'), ('default_value',
  //                  None)),)))))
}

TEST(LiteInterpreterDirectTest, DefaultArgsTensorinvSpecifyDefault) {
  // The second argument is specified, but its value is the same as the
  // default. It's treated as "not specified" since the value can be fetched
  // from the schema.
  Module m("m");
  m.define(R"(
    def forward(self, input):
      return torch.linalg_tensorinv(input, 2)
  )");
  torch::jit::MobileCode code(m.get_method("forward").graph(), "forward");
  auto arg_nums = code.op_to_num_specified_args();
  ASSERT_EQ(arg_nums.size(), 1);
  ASSERT_EQ(arg_nums["aten::linalg_tensorinv"], 1);
  std::vector<torch::jit::IValue> inputs;
  const int N = 4;
  auto input = torch::rand({N, N, N, N});
  inputs.emplace_back(input);
  testLiteModuleCompareResultTensors(m, inputs);
}

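// Same as testDefaultArgsPinv2, but passes out=input so the result is
// written in place into the input tensor.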
void testDefaultArgsPinvWithOutArg2(int num_args) {
  Module m("m");
  if (num_args == 1) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, out=input)
    )");
  } else if (num_args == 2) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, out=input)
    )");
  } else if (num_args == 3) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, True, out=input)
    )");
  }

  const int N = 28;
  auto input = torch::range(1, N * N, 1);
  input[0] = 10000; // a more stable matrix
  input = input.view({N, N});
  auto ref = m.run_method("forward", input);
  TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
  TORCH_CHECK(input.equal(ref.toTensor()));
}

TEST(LiteInterpreterDirectTest, DefaultArgsPinvWithOutArg) {
  // Test with different numbers of specified arguments plus an out arg.
  // Arguments not specified take their default values.
  for (int num_args = 1; num_args <= 3; ++num_args) {
    testDefaultArgsPinvWithOutArg2(num_args);
  }
}

TEST(LiteInterpreterDirectTest, DefaultArgsWithOutArg) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
      torch.add(x, h, out=x)
  )");

  std::vector<IValue> inputs;
  auto input_x = 2 * torch::ones({});
  auto input_h = torch::ones({});
  auto ref = m.run_method("forward", input_x, input_h);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bc.run_method("forward", input_x, input_h);
  AT_ASSERT(input_x.equal(4 * torch::ones({})));
}

TEST(LiteInterpreterDirectTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
  Module a("A");
  a.define(R"(
    def bar(self, x, y):
      return x + y
  )");
  Module b("B");
  b.register_module("A0", a);
  b.define(R"(
    def foo(self, x, y):
      return self.A0.bar(x, y) + 2
  )");
  Module c("C");
  c.register_module("B0", b);
  c.define(R"(
    def forward(self, x, y):
      return self.B0.foo(x, y) + 3
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  CompilationOptions options;
  auto lite_m = jitModuleToMobile(c, options);
  std::string error_pattern = R"(
  Module hierarchy:top(C)::<unknown>.B0(B)::foo.A0(A)::bar.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
      return self.B0.foo(x, y) + 3
             ~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in foo

    def foo(self, x, y):
      return self.A0.bar(x, y) + 2
             ~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in bar

    def bar(self, x, y):
      return x + y
             ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern);
}
#endif // !defined(FB_XPLAT_BUILD)

namespace {
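// Register the custom class so it can be constructed from TorchScript and
// round-tripped through pickling; the pickle hooks below carry no real
// state and simply re-create a fresh instance.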
static auto reg =
    torch::class_<TorchBindLiteInterpreterDirectTestStruct>(
        "_TorchScriptTesting",
        "_LiteInterpreterDirectTest")
        .def(torch::init<>())
        .def("get", &TorchBindLiteInterpreterDirectTestStruct::get)
        .def_pickle(
            // __getstate__
            [](const c10::intrusive_ptr<
                TorchBindLiteInterpreterDirectTestStruct>&) -> int64_t {
              return 0;
            },
            // __setstate__
            [](int64_t) {
              return c10::make_intrusive<
                  TorchBindLiteInterpreterDirectTestStruct>();
            });

} // namespace

TEST(LiteInterpreterDirectTest, OperatorCacheDifferentiatesDefaultArgs) {
  // Create 3 methods:
  //
  // 1. forward() returns a tensor with dtype=torch.int64 (4)
  // 2. forward2() returns a tensor with dtype=torch.float32 (6)
  // 3. forward3() returns a tensor with dtype=torch.float32, but
  //    the dtype is inferred from the input tensor's dtype
  //
  // If caching works correctly, the results from the full-JIT module
  // and the lite module will be the same. Otherwise they will differ,
  // since the cache entry for an operator with a different number of
  // specified arguments would be incorrectly reused.
  Module m("m");
  m.define(R"(
    def forward(self):
      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4)
      return ret1.fill_(25)
  )");
  m.define(R"(
    def forward2(self):
      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6)
      return ret1.fill_(32.0)
  )");
  m.define(R"(
    def forward3(self):
      ret1 = torch.new_empty(torch.zeros(10), [10])
      return ret1.fill_(12.0)
  )");

  std::vector<torch::jit::IValue> inputs;
  testLiteModuleCompareResultTensors(m, inputs, "forward");
  testLiteModuleCompareResultTensors(m, inputs, "forward2");
  testLiteModuleCompareResultTensors(m, inputs, "forward3");
}

} // namespace jit
} // namespace torch