// xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_lite_interpreter.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <test/cpp/jit/test_utils.h>
2 
3 #include <c10/core/TensorOptions.h>
4 #include <gtest/gtest.h>
5 #include <torch/csrc/autograd/generated/variable_factories.h>
6 #include <torch/csrc/jit/api/module.h>
7 #include <torch/csrc/jit/frontend/resolver.h>
8 #include <torch/csrc/jit/mobile/compatibility/backport.h>
9 #include <torch/csrc/jit/mobile/compatibility/backport_manager.h>
10 #include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
11 #include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
12 #include <torch/csrc/jit/mobile/import.h>
13 #include <torch/csrc/jit/mobile/interpreter.h>
14 #include <torch/csrc/jit/mobile/module.h>
15 #include <torch/csrc/jit/mobile/parse_bytecode.h>
16 #include <torch/csrc/jit/mobile/parse_operators.h>
17 #include <torch/csrc/jit/mobile/upgrader_mobile.h>
18 #include <torch/csrc/jit/serialization/export.h>
19 #include <torch/csrc/jit/serialization/import.h>
20 #include <torch/custom_class.h>
21 #include <torch/torch.h>
22 
23 #include <torch/csrc/jit/serialization/import_export_functions.h>
24 #include <unordered_set>
25 
26 // Tests go in torch::jit
27 namespace torch {
28 namespace jit {
29 
30 TEST(LiteInterpreterTest, UpsampleNearest2d) {
31   Module m("m");
32   m.define(R"(
33     def forward(self, input: Tensor, scale:float):
34       return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
35   )");
36 
37   std::vector<IValue> inputs;
38   inputs.emplace_back(torch::rand({1, 3, 128, 128}));
39   inputs.emplace_back(at::Scalar(2.0));
40   auto ref = m.forward(inputs);
41 
42   std::stringstream ss;
43   m._save_for_mobile(ss);
44   mobile::Module bc = _load_for_mobile(ss);
45   IValue res;
46   res = bc.forward(inputs);
47 
48   auto resd = res.toTensor();
49   auto refd = ref.toTensor();
50   ASSERT_TRUE(resd.equal(refd));
51 }
52 
53 TEST(LiteInterpreterTest, CheckAttrAccess) {
54   Module m("m");
55   m.register_attribute("mobile_optimized", BoolType::get(), true);
56 
57   std::stringstream ss;
58   m._save_for_mobile(ss);
59   mobile::Module bc = _load_for_mobile(ss);
60   bool mobile_optimized = bc.attr("mobile_optimized", false).toBool();
61 
62   AT_ASSERT(mobile_optimized);
63   m.setattr("mobile_optimized", false);
64   ss = std::stringstream();
65   m._save_for_mobile(ss);
66   bc = _load_for_mobile(ss);
67   mobile_optimized = bc.attr("mobile_optimized", false).toBool();
68 
69   AT_ASSERT(!mobile_optimized);
70 }
71 
72 TEST(LiteInterpreterTest, MethodInvocation) { // NOLINT (use =delete in gtest)
73   const std::vector<std::string> test_programs{
74       // test invoking a method with default parameter
75       R"(
76       def test_func(self, x, b : int = 4):
77         return self.foo + x + b
78       )",
79       // inner method call with default parameter (gets inlined)
80       R"(
81       def add_with_default_arg(self, x, b : int = 4):
82         return self.foo + x + b
83       def test_func(self, x):
84         return self.add_with_default_arg(x)  # invoke method w/ default arg
85       )",
86       // simple method call
87       R"(
88       def test_func(self, x):
89         b = 4
90         return self.foo + x + b
91       )",
92   };
93   for (const auto& test_program : test_programs) {
94     Module m("m");
95     m.register_parameter("foo", torch::ones({}), false);
96     m.define(test_program);
97 
98     const int fortyTwo = 42; // (keep linter happy)
99     auto minput = fortyTwo * torch::ones({});
100     auto ref = m.run_method("test_func", minput);
101 
102     std::stringstream ss;
103     m._save_for_mobile(ss);
104     mobile::Module bc = _load_for_mobile(ss);
105     const auto& test_func = bc.get_method("test_func");
106     IValue res;
107     for (int i = 0; i < 3; ++i) {
108       res = test_func({minput});
109     }
110 
111     auto resd = res.toTensor().item<float>();
112     auto refd = ref.toTensor().item<float>();
113     AT_ASSERT(resd == refd);
114   }
115 }
116 
117 TEST(LiteInterpreterTest, Conv) {
118   auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
119   if (s && strcmp(s, "1") == 0)
120     return;
121 
122   std::vector<torch::jit::IValue> inputs;
123 
124   Module m("m");
125   m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
126   m.register_parameter("bias", torch::ones({20}), false);
127   m.define(R"(
128     def forward(self, input):
129       return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
130   )");
131 
132   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
133   inputs.push_back(torch::ones({1, 1, 28, 28}));
134 
135   auto outputref = m.forward(inputs).toTensor();
136 
137   std::stringstream ss;
138   m._save_for_mobile(ss);
139   mobile::Module bc = _load_for_mobile(ss);
140   IValue res;
141   for (int i = 0; i < 3; ++i) {
142     res = bc.get_method("forward")(inputs);
143   }
144   auto output = res.toTensor();
145   AT_ASSERT(outputref.dim() == output.dim());
146   AT_ASSERT(
147       outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
148 }
149 
150 TEST(LiteInterpreterTest, Inline) {
151   Module m("m");
152   m.define(R"JIT(
153   def foo1(self, x):
154       return x + 1
155 
156   def foo2(self, x):
157       return self.foo1(x) + 2
158 
159   def foo3(self, x):
160       return self.foo2(x) + 3
161   )JIT");
162   std::stringstream ss;
163   m._save_for_mobile(ss);
164   mobile::Module bc = _load_for_mobile(ss);
165   std::vector<torch::jit::IValue> inputs({torch::ones({})});
166   auto output = bc.get_method("foo3")(inputs);
167   AT_ASSERT(output.toTensor().item<float>() == 7.0);
168 }
169 
170 TEST(LiteInterpreterTest, Tuple) {
171   Module m("m");
172   m.define(R"JIT(
173   def foo(self, x):
174       return (1, 2, x + 3)
175 
176   def forward(self, x):
177       tuple = self.foo(x)
178       return tuple
179   )JIT");
180   std::stringstream ss;
181   m._save_for_mobile(ss);
182   mobile::Module bc = _load_for_mobile(ss);
183   std::vector<torch::jit::IValue> inputs({torch::ones({})});
184   auto output = bc.get_method("forward")(inputs);
185   AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);
186 }
187 
188 TEST(LiteInterpreterTest, AtenFormat) {
189   Module m("m");
190   m.define(R"""(
191   def forward(self, fmt:str="first {} {}", num:str="abc"):
192     x = 2
193     x = x * x
194     return fmt.format(num, x)
195   )""");
196   std::stringstream ss;
197   m._save_for_mobile(ss);
198   mobile::Module bc = _load_for_mobile(ss);
199   std::vector<torch::jit::IValue> inputs;
200   auto output_bc = bc.get_method("forward")(inputs);
201   auto output_m = m.get_method("forward")(inputs);
202   // std::cout << output_m.toStringRef() << "\n"
203   //           << output_bc.toStringRef() << std::endl;
204   AT_ASSERT(output_m.toStringRef() == output_bc.toStringRef());
205 }
206 
207 TEST(LiteInterpreterTest, PrimDevice) {
208   Module m("m");
209   m.define(R"""(
210   def forward(self, x:torch.Tensor):
211     return x.device
212   )""");
213   std::stringstream ss;
214   m._save_for_mobile(ss);
215   mobile::Module bc = _load_for_mobile(ss);
216   std::vector<torch::jit::IValue> inputs;
217   auto minput = 3.5 * torch::ones({});
218   inputs.emplace_back(minput);
219   auto output_bc = bc.get_method("forward")(inputs);
220   auto output_m = m.get_method("forward")(inputs);
221   AT_ASSERT(output_bc.toDevice().str() == output_m.toDevice().str());
222 }
223 
224 TEST(LiteInterpreterTest, Dict) {
225   Module m("m");
226   m.define(R"JIT(
227   def foo(self, x):
228       return {"result": x + 1}
229 
230   def forward(self, x):
231       d = self.foo(x)
232       return d
233   )JIT");
234   std::stringstream ss;
235   m._save_for_mobile(ss);
236   mobile::Module bc = _load_for_mobile(ss);
237   std::vector<torch::jit::IValue> inputs({torch::ones({})});
238   auto output = bc.get_method("forward")(inputs);
239   AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
240 }
241 
242 TEST(LiteInterpreterTest, List) {
243   Module m("m");
244   m.define(R"JIT(
245   def foo(self, x):
246       return [x + 2]
247 
248   def forward(self, x):
249       d = self.foo(x)
250       return d
251   )JIT");
252   std::stringstream ss;
253   m._save_for_mobile(ss);
254   mobile::Module bc = _load_for_mobile(ss);
255   std::vector<torch::jit::IValue> inputs({torch::ones({})});
256   auto output = bc.get_method("forward")(inputs);
257   auto server_output = m.forward(inputs);
258   EXPECT_EQ(output.toList().get(0).toTensor().item().toInt(), 3);
259   EXPECT_EQ(output, server_output);
260 }
261 
262 TEST(LiteInterpreterTest, PrimOverload) {
263   /*
264   // temporarily disabled
265   script::Module m("m");
266   m.define(R"JIT(
267   def forward(self, x):
268       result = [1, 2]
269       result.append(3)
270       return result
271   )JIT");
272   std::stringstream ss;
273   m._save_for_mobile(ss);
274   mobile::Module bc = _load_for_mobile(ss);
275   std::vector<torch::jit::IValue> inputs({torch::ones({})});
276   auto output = bc.get_method("forward")(inputs);
277   AT_ASSERT(output.toIntList()[2] == 3);
278   */
279 }
280 
281 TEST(LiteInterpreterTest, Prim) {
282   Module m("m");
283   m.define(R"JIT(
284         def forward(self, x):
285             return int(x)
286   )JIT");
287 
288   std::vector<IValue> inputs;
289   auto minput = 3.5 * torch::ones({});
290   inputs.emplace_back(minput);
291   auto ref = m.run_method("forward", minput);
292 
293   std::stringstream ss;
294   m._save_for_mobile(ss);
295   mobile::Module bc = _load_for_mobile(ss);
296   IValue res;
297   for (int i = 0; i < 3; ++i) {
298     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
299     auto bcinputs = inputs;
300     res = bc.get_method("forward")(bcinputs);
301   }
302 
303   auto resi = res.toInt();
304   auto refi = ref.toInt();
305   AT_ASSERT(resi == refi);
306 }
307 
308 TEST(LiteInterpreterTest, PrimScalar) {
309   Module m("m");
310   m.define(R"JIT(
311         def forward(self, x):
312             return int(x.item())
313   )JIT");
314 
315   std::vector<IValue> inputs;
316   auto minput = 3.5 * torch::ones({});
317   inputs.emplace_back(minput);
318   auto ref = m.run_method("forward", minput);
319 
320   std::stringstream ss;
321   m._save_for_mobile(ss);
322   mobile::Module bc = _load_for_mobile(ss);
323   IValue res;
324   for (int i = 0; i < 3; ++i) {
325     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
326     auto bcinputs = inputs;
327     res = bc.get_method("forward")(bcinputs);
328   }
329 
330   auto resi = res.toInt();
331   auto refi = ref.toInt();
332   AT_ASSERT(resi == refi);
333 }
334 
335 TEST(LiteInterpreterTest, LoadOrigJit) {
336   Module m("m");
337   m.register_parameter("foo", torch::ones({}), false);
338   m.define(R"(
339     def forward(self, x):
340       b = 4
341       return self.foo + x + b
342   )");
343   std::stringstream ss;
344   m.save(ss);
345   ASSERT_THROWS_WITH_MESSAGE(_load_for_mobile(ss), "file not found");
346 }
347 
348 TEST(LiteInterpreterTest, WrongMethodName) {
349   Module m("m");
350   m.register_parameter("foo", torch::ones({}), false);
351   m.define(R"(
352     def add(self, x):
353       b = 4
354       return self.foo + x + b
355   )");
356   std::stringstream ss;
357   m._save_for_mobile(ss);
358   mobile::Module bc = _load_for_mobile(ss);
359   std::vector<IValue> inputs;
360   auto minput = 5 * torch::ones({});
361   inputs.emplace_back(minput);
362   ASSERT_THROWS_WITH_MESSAGE(
363       bc.get_method("forward")(inputs), "is not defined");
364 }
365 
366 TEST(LiteInterpreterTest, SetState) {
367   Module m("m");
368   m.register_parameter("foo", torch::ones({}), false);
369   m.define(R"(
370     def __getstate__(self):
371       return self.foo + self.foo
372     def __setstate__(self, a):
373       self.foo = a
374     def forward(self, x):
375       b = 4
376       return self.foo + x + b
377   )");
378 
379   std::vector<IValue> inputs;
380   auto minput = 5 * torch::ones({});
381   inputs.emplace_back(minput);
382 
383   std::stringstream ms;
384   m.save(ms);
385   auto loaded_m = load(ms);
386   auto ref = loaded_m.run_method("forward", minput);
387 
388   std::stringstream ss;
389   m._save_for_mobile(ss);
390   mobile::Module bc = _load_for_mobile(ss);
391   IValue res;
392   for (int i = 0; i < 3; ++i) {
393     // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
394     auto bcinputs = inputs;
395     res = bc.get_method("forward")(bcinputs);
396   }
397 
398   auto resd = res.toTensor().item<float>();
399   auto refd = ref.toTensor().item<float>();
400   AT_ASSERT(resd == refd);
401 }
402 
403 class TorchBindLiteInterpreterTestStruct
404     : public torch::jit::CustomClassHolder {
405  public:
406   std::string get(at::Tensor t) {
407     std::stringstream ss;
408     ss << "Hello! Your tensor has ";
409     ss << t.numel();
410     ss << " elements!";
411     return ss.str();
412   }
413 };
414 
415 namespace {
416 struct ClassNamespaceValue : public SugaredValue {
417   explicit ClassNamespaceValue(c10::QualifiedName name)
418       : basename_(std::move(name)) {}
419 
420   std::shared_ptr<SugaredValue> attr(
421       const SourceRange& loc,
422       GraphFunction& m,
423       const std::string& name) override {
424     const auto fullName = c10::QualifiedName(basename_, name);
425 
426     // Check to see if it is a custom class.
427     if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
428       return std::make_shared<ClassValue>(custom_class);
429     }
430 
431     // If it's not a custom class, assume it's another namespace
432     // NOLINTNEXTLINE(performance-move-const-arg)
433     return std::make_shared<ClassNamespaceValue>(std::move(fullName));
434   }
435 
436   std::string kind() const override {
437     return "Class Namespace";
438   }
439 
440  private:
441   c10::QualifiedName basename_;
442 };
443 
444 struct TestModuleResolver : public Resolver {
445   std::shared_ptr<SugaredValue> resolveValue(
446       const std::string& name,
447       GraphFunction& m,
448       const SourceRange& loc) override {
449     if (name == "torch") {
450       return std::make_shared<BuiltinModule>("aten");
451     } else if (name == "__torch__") {
452       return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
453     }
454 
455     return nullptr;
456   }
457 
458   TypePtr resolveType(const std::string& name, const SourceRange& loc)
459       override {
460     return nullptr;
461   }
462 };
463 } // namespace
464 
465 TEST(LiteInterpreterTest, BuiltinClass) {
466   script::Module m("m");
467 
468   auto cls = getCustomClass(
469       "__torch__.torch.classes._TorchScriptTesting._LiteInterpreterTest");
470   TORCH_INTERNAL_ASSERT(cls);
471   c10::intrusive_ptr<torch::CustomClassHolder> obj_holder;
472   m.register_attribute("my_obj", cls, IValue::make_capsule(obj_holder));
473 
474   m.register_parameter("foo", torch::ones({}), false);
475   m.define(
476       R"(
477     def __getstate__(self):
478       return 1
479     def __setstate__(self, a):
480       self.my_obj = __torch__.torch.classes._TorchScriptTesting._LiteInterpreterTest()
481 
482     def forward(self, x) -> str:
483       return self.my_obj.get(x)
484   )",
485       std::make_shared<TestModuleResolver>());
486 
487   std::stringstream ss;
488   m._save_for_mobile(ss);
489   mobile::Module bc = _load_for_mobile(ss);
490   auto res =
491       bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
492   const auto& str = res.toStringRef();
493   std::string expected = "Hello! Your tensor has 12 elements!";
494   AT_ASSERT(str == expected);
495 }
496 
497 TEST(LiteInterpreterTest, BuiltinFunction) {
498   script::Module m("m");
499   auto custom_class_obj =
500       make_custom_class<TorchBindLiteInterpreterTestStruct>();
501   m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj);
502   m.define(R"(
503     def forward(self, x) -> str:
504       return self.my_obj.get(x)
505   )");
506 
507   std::stringstream ss;
508   m._save_for_mobile(ss);
509   mobile::Module bc = _load_for_mobile(ss);
510   auto res =
511       bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
512   // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
513   auto str = res.toStringRef();
514   std::string expected = "Hello! Your tensor has 12 elements!";
515   AT_ASSERT(str == expected);
516 }
517 
518 #if !defined FB_XPLAT_BUILD
519 TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) {
520   auto runtime_bytecode_version = _get_runtime_bytecode_version();
521   AT_ASSERT(
522       runtime_bytecode_version ==
523       caffe2::serialize::kMaxSupportedBytecodeVersion);
524 }
525 
526 TEST(LiteInterpreterTest, GetRuntimeOperatorsVersion) {
527   auto runtime_operators_version = _get_runtime_operators_min_max_versions();
528   AT_ASSERT(
529       runtime_operators_version.first ==
530           caffe2::serialize::kMinSupportedFileFormatVersion &&
531       runtime_operators_version.second ==
532           caffe2::serialize::kMaxSupportedFileFormatVersion);
533 }
534 
535 /**
536  * The test below is disarmed for FB internal xplat builds since
537  * BUCK requires us to pass in the script_module_v4.ptl file
538  * as a resource dependency of the build rule for this file, and
539  * we would need to access it via the C++ Resources API instead
540  * of directly reading from disk (which is what the open source
541  * build/run does).
542  */
543 TEST(LiteInterpreterTest, GetByteCodeVersion) {
544   std::string filePath(__FILE__);
545   auto test_model_file_v4 =
546       filePath.substr(0, filePath.find_last_of("/\\") + 1);
547   test_model_file_v4.append("script_module_v4.ptl");
548 
549   auto version_v4 = _get_model_bytecode_version(test_model_file_v4);
550   AT_ASSERT(version_v4 == 4);
551 }
552 
553 #endif // !defined(FB_XPLAT_BUILD)
554 
555 TEST(LiteInterpreterTest, GetContainTypes) {
556   Module m("m");
557   m.define(R"(
558     def forward(self):
559       return 3
560   )");
561 
562   std::stringstream ss;
563   m._save_for_mobile(ss, {}, true);
564 
565   _get_mobile_model_contained_types(ss);
566 }
567 
568 namespace {
569 
570 void compareModelOutput(
571     c10::ArrayRef<IValue> actual_result_list,
572     const std::vector<IValue>& expect_result_list) {
573   AT_ASSERT(actual_result_list.size() == expect_result_list.size());
574   AT_ASSERT(
575       actual_result_list[0].toTensor().equal(expect_result_list[0].toTensor()));
576   AT_ASSERT(
577       actual_result_list[1].toTensor().dim() ==
578       expect_result_list[1].toTensor().dim());
579   AT_ASSERT(
580       actual_result_list[2].toTensor().equal(expect_result_list[2].toTensor()));
581   AT_ASSERT(
582       actual_result_list[3].toTensor().equal(expect_result_list[3].toTensor()));
583   ASSERT_EQ(
584       actual_result_list[4].toStringRef(), expect_result_list[4].toStringRef());
585   ASSERT_EQ(actual_result_list[5].toBool(), expect_result_list[5].toBool());
586   ASSERT_EQ(actual_result_list[6].toBool(), expect_result_list[6].toBool());
587   ASSERT_EQ(actual_result_list[7].toBool(), expect_result_list[7].toBool());
588   AT_ASSERT(
589       actual_result_list[8].toTensor().equal(expect_result_list[8].toTensor()));
590   ASSERT_EQ(
591       actual_result_list[9].toStringRef(), expect_result_list[9].toStringRef());
592   ASSERT_EQ(actual_result_list[10].toInt(), expect_result_list[10].toInt());
593   ASSERT_EQ(actual_result_list[11].toBool(), expect_result_list[11].toBool());
594 }
595 
596 void runAndCheckTorchScriptModel(
597     std::stringstream& input_model_stream,
598     const std::vector<IValue>& input_data,
599     const std::vector<IValue>& expect_result_list,
600     const uint64_t expect_version) {
601   auto actual_version = _get_model_bytecode_version(input_model_stream);
602   AT_ASSERT(actual_version == expect_version);
603 
604   // Load and run the backported model, then compare the result with the
605   // expected result.
606   Module m_mobile = load(input_model_stream);
607 
608   auto actual_result = m_mobile.forward(input_data);
609   const auto& actual_result_list = actual_result.toTupleRef().elements();
610   compareModelOutput(actual_result_list, expect_result_list);
611 }
612 
613 void runAndCheckBytecodeModel(
614     std::stringstream& input_model_stream,
615     const std::vector<IValue>& input_data,
616     const std::vector<IValue>& expect_result_list,
617     const uint64_t expect_version) {
618   auto actual_version = _get_model_bytecode_version(input_model_stream);
619   AT_ASSERT(actual_version == expect_version);
620 
621   // Load and run the backported model, then compare the result with the
622   // expected result.
623   Module m_mobile = load(input_model_stream);
624 
625   auto actual_result = m_mobile.forward(input_data);
626   const auto& actual_result_list = actual_result.toTupleRef().elements();
627 
628   compareModelOutput(actual_result_list, expect_result_list);
629 }
630 
631 void backportAllVersionCheck(
632     std::stringstream& test_model_file_stream,
633     std::vector<IValue>& input_data,
634     std::vector<IValue>& expect_result_list,
635     const uint64_t expect_from_version) {
636   auto from_version = _get_model_bytecode_version(test_model_file_stream);
637   EXPECT_EQ(from_version, expect_from_version);
638   AT_ASSERT(from_version > 0);
639 
640   // Backport script_module_v5.ptl to an older version
641   constexpr int64_t minimum_to_version = 4;
642   auto current_to_version = from_version - 1;
643 
644   // Verify that every candidate to_version works as expected. Backporting to
645   // any version >= minimum_to_version should succeed.
646   while (current_to_version >= minimum_to_version) {
647     // Do not declare std::stringstream oss outside of the while loop:
648     // oss.clear() doesn't reset the stream content, it only clears the error
649     // state flags, which leaves a problematic stream. Instead, it's cleaner and
650     // safer to declare a new std::stringstream on each iteration.
651     std::stringstream oss;
652     bool backPortSuccess =
653         _backport_for_mobile(test_model_file_stream, oss, current_to_version);
654     AT_ASSERT(backPortSuccess);
655 
656     // Check backport model version
657     auto backport_version = _get_model_bytecode_version(oss);
658     backport_version = _get_model_bytecode_version(oss);
659     AT_ASSERT(backport_version == current_to_version);
660 
661     // Load and run the backported model, then compare the result with the
662     // expected result.
663     runAndCheckBytecodeModel(
664         oss, input_data, expect_result_list, current_to_version);
665     oss.seekg(0, oss.beg);
666     runAndCheckTorchScriptModel(
667         oss, input_data, expect_result_list, current_to_version);
668 
669     current_to_version--;
670   }
671   //  backport to minimum version - 1 should fail
672   std::stringstream oss;
673   bool backPortSuccess =
674       _backport_for_mobile(test_model_file_stream, oss, minimum_to_version - 1);
675   AT_ASSERT(!backPortSuccess);
676 }
677 } // namespace
678 
679 #if !defined FB_XPLAT_BUILD
680 TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
681   torch::jit::Module module("m");
682   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
683   module.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
684   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
685   module.register_parameter("bias", torch::ones({20}), false);
686   module.define(R"(
687     def fn(self, x:float=1.0):
688       return x
689 
690     def forward(self, input):
691       x1 = torch.zeros(2, 2)
692       x2 = torch.empty_like(torch.empty(2, 2))
693       x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
694       # Add torch.add operator to cover bytecode version bump from 6 to 7
695       # for bytecode version 7, the main change is to support default arguments with out arguments
696       x = 2 * torch.ones(1)
697       h = torch.ones(1)
698       torch.add(x, h, out=x)
699       device = torch.ones(1, 1).cpu().device.type
700       is_cuda = x1.is_cuda
701       bool_val = True
702       check_is = [] is None
703       check_is_not = [1] is not None
704       check_not = not bool_val
705       num_to_tensor = torch.tensor([self.fn()])
706       d = {"a": "abc"}
707       check_dict_index = d["a"]
708       check_dim = x1.dim()
709       return (
710         x1, x2, x3, x, device, is_cuda, check_is,
711         check_is_not, num_to_tensor, check_dict_index,
712         check_dim, check_not
713         )
714       )");
715 
716   torch::jit::Module module_freeze = freeze(module);
717 
718   std::stringstream input_model_stream;
719   module_freeze._save_for_mobile(
720       input_model_stream,
721       /*extra_files=*/{},
722       /*save_mobile_debug_info=*/false,
723       /*use_flatbuffer=*/true);
724   std::vector<IValue> input_data =
725       std::vector<IValue>({torch::ones({1, 1, 28, 28})});
726   std::vector<IValue> expect_result_list;
727   expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float) * 0);
728   expect_result_list.emplace_back(at::ones({2, 2}, ScalarType::Float));
729   expect_result_list.emplace_back(
730       at::ones({1, 20, 24, 24}, ScalarType::Float) * 26);
731   expect_result_list.emplace_back(3 * at::ones({1}));
732   // "cpu" False, False, True, tensor(1), "abc", 2, False)
733   expect_result_list.emplace_back(c10::IValue("cpu"));
734   expect_result_list.emplace_back(c10::IValue(false));
735   expect_result_list.emplace_back(c10::IValue(false));
736   expect_result_list.emplace_back(c10::IValue(true));
737   expect_result_list.emplace_back(c10::IValue(at::ones({1})));
738   expect_result_list.emplace_back(c10::IValue("abc"));
739   expect_result_list.emplace_back(c10::IValue(2));
740   expect_result_list.emplace_back(c10::IValue(false));
741 
742   backportAllVersionCheck(
743       input_model_stream,
744       input_data,
745       expect_result_list,
746       9); // flatbuffer starts at 9
747 }
748 #endif // !defined(FB_XPLAT_BUILD)
749 
750 TEST(LiteInterpreterTest, GetRuntimeOpsAndInfo) {
751   auto runtime_ops = _get_runtime_ops_and_info();
752   // Ballpark estimate of the minimal number of ops; just used to
753   // verify that the API returns a reasonably large number.
754   AT_ASSERT(runtime_ops.size() > 2900);
755 }
756 
757 TEST(LiteInterpreterTest, isCompatibleSuccess) {
758   // test trivial success case
759   auto runtime_info = RuntimeCompatibilityInfo::get();
760   std::unordered_map<std::string, OperatorInfo> model_ops;
761   model_ops["aten::add.Scalar"] = OperatorInfo{2};
762 
763   std::unordered_set<std::string> types = {"List", "int", "NamedTuple"};
764   auto model_info = ModelCompatibilityInfo{
765       caffe2::serialize::kMaxSupportedBytecodeVersion,
766       model_ops,
767       types,
768       _get_runtime_bytecode_min_max_versions().first};
769 
770   AT_ASSERT(
771       is_compatible(runtime_info, model_info).status ==
772       ModelCompatibilityStatus::OK);
773 }
774 
775 TEST(LiteInterpreterTest, isCompatibleFail) {
776   // test trivial failure due to ops
777   std::unordered_map<std::string, OperatorInfo> model_ops;
778   model_ops["aten::add.Scalar"] = OperatorInfo{2};
779   auto model_info = ModelCompatibilityInfo{
780       caffe2::serialize::kMaxSupportedBytecodeVersion, model_ops};
781   std::unordered_map<std::string, OperatorInfo> runtime_ops;
782   runtime_ops["aten::add.Int"] = OperatorInfo{2};
783   auto runtime_info = RuntimeCompatibilityInfo{
784       std::pair<uint64_t, uint64_t>(
785           caffe2::serialize::kMinSupportedBytecodeVersion,
786           caffe2::serialize::kMaxSupportedBytecodeVersion),
787       runtime_ops,
788       _get_mobile_supported_types()};
789 
790   auto result = is_compatible(runtime_info, model_info);
791   AT_ASSERT(result.status == ModelCompatibilityStatus::ERROR);
792   AT_ASSERT(
793       result.errors[0] ==
794       "Operator 'aten::add.Scalar' missing from runtime (not found)");
795 
796   // test trivial failure due to bytecode greater than max supported bytecode
797   // version
798   runtime_ops["aten::add.Scalar"] = OperatorInfo{2};
799   runtime_info = RuntimeCompatibilityInfo{
800       std::pair<uint64_t, uint64_t>(
801           caffe2::serialize::kMinSupportedBytecodeVersion,
802           caffe2::serialize::kMaxSupportedBytecodeVersion),
803       runtime_ops,
804       _get_mobile_supported_types()};
805   model_info.bytecode_version =
806       caffe2::serialize::kMaxSupportedBytecodeVersion + 1;
807 
808   result = is_compatible(runtime_info, model_info);
809   AT_ASSERT(result.status == ModelCompatibilityStatus::ERROR);
810 
811   // test trivial failure due to bytecode less than min supported bytecode
812   // version
813   runtime_ops["aten::add.Scalar"] = OperatorInfo{2};
814   runtime_info = RuntimeCompatibilityInfo{
815       std::pair<uint64_t, uint64_t>(
816           caffe2::serialize::kMinSupportedBytecodeVersion,
817           caffe2::serialize::kMaxSupportedBytecodeVersion),
818       runtime_ops,
819       _get_mobile_supported_types()};
820   model_info.bytecode_version =
821       caffe2::serialize::kMinSupportedBytecodeVersion - 1;
822 
823   result = is_compatible(runtime_info, model_info);
824   AT_ASSERT(result.status == ModelCompatibilityStatus::ERROR);
825 
826   // test trivial failure due to type
827   runtime_info = RuntimeCompatibilityInfo::get();
828   std::unordered_set<std::string> types = {"List", "int", "Sequence"};
829 
830   model_info = ModelCompatibilityInfo{
831       caffe2::serialize::kMaxSupportedBytecodeVersion,
832       model_ops,
833       types,
834       _get_runtime_bytecode_min_max_versions().first};
835 
836   AT_ASSERT(
837       is_compatible(runtime_info, model_info).status ==
838       ModelCompatibilityStatus::ERROR);
839 
840   // test trivial failure due to operator version
841   runtime_info = RuntimeCompatibilityInfo::get();
842 
843   model_info = ModelCompatibilityInfo{
844       caffe2::serialize::kMaxSupportedBytecodeVersion, model_ops, {}, 0};
845 
846   AT_ASSERT(
847       is_compatible(runtime_info, model_info).status ==
848       ModelCompatibilityStatus::ERROR);
849 }
850 
851 TEST(LiteInterpreterTest, Eval) {
852   std::vector<torch::jit::IValue> inputs;
853 
854   Module m("m");
855   m.define(R"(
856     def __init__(self, x):
857       self.training = True
858 
859     def forward(self, input):
860       return torch.dropout(input, 1.0, self.training)
861   )");
862 
863   // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
864   inputs.push_back(torch::ones({1, 1, 28, 28}));
865   m.eval();
866   auto outputref = m.forward(inputs).toTensor();
867 
868   // save m in training mode to make sure that mobile eval() will correctly
869   // change back to eval mode
870   m.train();
871   std::stringstream ss;
872   m._save_for_mobile(ss);
873   mobile::Module bc = _load_for_mobile(ss);
874   bc.eval();
875   IValue res;
876   for (int i = 0; i < 3; ++i) {
877     res = bc.get_method("forward")(inputs);
878   }
879   auto output = res.toTensor();
880   AT_ASSERT(outputref.dim() == output.dim());
881   AT_ASSERT(
882       outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
883 }
884 
885 TEST(LiteInterpreterTest, FindWrongMethodName) {
886   Module m("m");
887   m.register_parameter("foo", torch::ones({}), false);
888   m.define(R"(
889     def add(self, x):
890       b = 4
891       return self.foo + x + b
892   )");
893   std::stringstream ss;
894   m._save_for_mobile(ss);
895   mobile::Module bc = _load_for_mobile(ss);
896   ASSERT_TRUE(bc.find_method("forward") == std::nullopt);
897 }
898 
899 TEST(LiteInterpreterTest, FindAndRunMethod) {
900   Module m("m");
901   m.register_parameter("foo", torch::ones({}), false);
902   m.define(R"(
903     def add_it(self, x):
904       b = 4
905       return self.foo + x + b
906   )");
907 
908   std::vector<IValue> inputs;
909   auto minput = 5 * torch::ones({});
910   inputs.emplace_back(minput);
911   auto ref = m.get_method("add_it")(inputs);
912 
913   std::stringstream ss;
914   m._save_for_mobile(ss);
915   mobile::Module bc = _load_for_mobile(ss);
916   IValue res;
917   for (int i = 0; i < 3; ++i) {
918     auto bcinputs = inputs;
919     auto method = bc.find_method("add_it");
920     AT_ASSERT(method != std::nullopt);
921     res = (*method)(std::move(bcinputs));
922   }
923 
924   auto resd = res.toTensor().item<float>();
925   auto refd = ref.toTensor().item<float>();
926   AT_ASSERT(resd == refd);
927 }
928 
929 TEST(LiteInterpreterTest, RunMethodVariadic) {
930   Module m("m");
931   m.register_parameter("foo", torch::ones({}), false);
932   m.define(R"(
933     def add_three(self, x, y):
934       return self.foo + x + y
935   )");
936 
937   std::vector<IValue> inputs;
938   auto inputx = 5 * torch::ones({});
939   auto inputy = 4 * torch::ones({});
940   auto ref = m.run_method("add_three", inputx, inputy);
941 
942   std::stringstream ss;
943   m._save_for_mobile(ss);
944   mobile::Module bc = _load_for_mobile(ss);
945   IValue res = bc.run_method("add_three", inputx, inputy);
946 
947   auto resd = res.toTensor().item<float>();
948   auto refd = ref.toTensor().item<float>();
949   AT_ASSERT(resd == refd);
950 }
951 
952 TEST(LiteInterpreterTest, DuplicateSetState) {
953   Module m("M");
954   m.register_parameter("foo", torch::ones({}), false);
955   m.define(R"(
956     def __getstate__(self):
957       return self.foo + self.foo
958     def __setstate__(self, a):
959       self.foo = a
960     def forward(self, x):
961       b = 4
962       return self.foo + x + b
963   )");
964 
965   Module b("B");
966   b.register_module("M0", m);
967   b.register_module("M1", m);
968   b.define(R"(
969     def forward(self, x):
970       return self.M0.forward(x) + self.M1.forward(x)
971   )");
972 
973   std::stringstream ss;
974   m._save_for_mobile(ss);
975   mobile::Module bc = _load_for_mobile(ss);
976   const auto methods = bc.get_methods();
977   const size_t expected_n = 3;
978   ASSERT_EQ(methods.size(), expected_n);
979 }
980 
981 TEST(LiteInterpreterTest, ExtraFiles) {
982   const auto script = R"JIT(
983     def forward(self):
984         x = torch.rand(5, 5)
985         x = x.mm(x)
986         return x
987   )JIT";
988 
989   auto module =
990       std::make_shared<Module>("Module", std::make_shared<CompilationUnit>());
991   module->define(script);
992   std::ostringstream oss;
993   std::unordered_map<std::string, std::string> extra_files;
994   extra_files["metadata.json"] = "abc";
995   extra_files["mobile_info.json"] = "{\"key\": 23}";
996   module->_save_for_mobile(oss, extra_files);
997 
998   std::istringstream iss(oss.str());
999   std::unordered_map<std::string, std::string> loaded_extra_files;
1000   loaded_extra_files["metadata.json"] = "";
1001   torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
1002   ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
1003 
1004   loaded_extra_files.clear();
1005   std::vector<std::string> all_files =
1006       caffe2::serialize::PyTorchStreamReader(&iss).getAllRecords();
1007 
1008   for (auto& file_name : all_files) {
1009     if (file_name.find("extra/") == 0) {
1010       loaded_extra_files[file_name.substr(6)] = "";
1011     }
1012   }
1013   iss.seekg(0, iss.beg);
1014   torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
1015   ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
1016   ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
1017 
1018   std::unordered_map<std::string, std::string>
1019       loaded_extra_files_without_explicit_mapping;
1020   iss.seekg(0, iss.beg);
1021   torch::jit::_load_for_mobile(
1022       iss,
1023       torch::kCPU,
1024       loaded_extra_files_without_explicit_mapping,
1025       MobileModuleLoadOptions::PARSE_ALL_EXTRA_FILE_MAPS);
1026   ASSERT_EQ(
1027       loaded_extra_files_without_explicit_mapping["metadata.json"], "abc");
1028   ASSERT_EQ(
1029       loaded_extra_files_without_explicit_mapping["mobile_info.json"],
1030       "{\"key\": 23}");
1031 }
1032 
1033 TEST(LiteInterpreterTest, OpNameExportFetchRootOperators) {
1034   torch::jit::Module m("m");
1035   m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
1036   m.register_parameter("bias", torch::ones({20}), false);
1037   m.define(R"(
1038     def forward(self, input):
1039       x1 = torch.zeros(2, 2)
1040       x2 = torch.empty_like(torch.empty(2, 2))
1041       x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
1042       return (x1, x2, x3)
1043   )");
1044   m.eval();
1045 
1046   std::stringstream ss;
1047   m._save_for_mobile(ss);
1048 
1049   torch::jit::mobile::Module ptl_model = torch::jit::_load_for_mobile(ss);
1050   std::set<std::string> operator_names =
1051       torch::jit::mobile::_export_operator_list(ptl_model);
1052   std::set<std::string> expected_operator_names = {
1053       "aten::_convolution",
1054       "aten::empty.memory_format",
1055       "aten::empty_like",
1056       "aten::zeros",
1057   };
1058   EXPECT_EQ(operator_names, expected_operator_names)
1059       << "Expected the root operator lists to be the same";
1060 }
1061 
1062 TEST(LiteInterpreterTest, DefaultArgsConv) {
1063   auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
1064   if (s && strcmp(s, "1") == 0)
1065     return;
1066 
1067   std::vector<torch::jit::IValue> inputs;
1068 
1069   Module m("m");
1070   m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
1071   m.register_parameter("bias", torch::ones({20}), false);
1072   m.define(R"(
1073     def forward(self, input):
1074       return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1)
1075   )");
1076 
1077   inputs.push_back(torch::ones({1, 1, 28, 28}));
1078 
1079   auto outputref = m.forward(inputs).toTensor();
1080 
1081   std::stringstream ss;
1082   m._save_for_mobile(ss);
1083   mobile::Module bc = _load_for_mobile(ss);
1084   IValue res;
1085   for (int i = 0; i < 1; ++i) {
1086     res = bc.get_method("forward")(inputs);
1087   }
1088   auto output = res.toTensor();
1089   AT_ASSERT(outputref.dim() == output.dim());
1090   AT_ASSERT(output.equal(outputref));
1091 }
1092 
1093 TEST(RunTimeTest, ParseBytecode) {
1094   // A simple example showing bytecode that can be used independently of
1095   // PyTorch TorchScript serialization (unpickler, etc.) and the operator
1096   // library. It has basic control flow (if/else) and basic data orchestration
1097   // (list construction). The original PyTorch program:
1098 
1099   //  class Module(torch.nn.Module):
1100   //
1101   //    def __init__(self) -> None:
1102   //      super().__init__()
1103   //
1104   //    def forward(self, x: int, h: int, xfirst: bool):
1105   //      if xfirst:
1106   //        return [x, h]
1107   //      else:
1108   //        return [h, x]
1109 
1110   // 1. Prepare the bytecode. In reality it can come from a customized
1111   // deserializer.
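  // Editorial note (not part of the original test): each instruction below is a
  // tuple of (opcode, X, N); the meaning of the two integer operands depends on
  // the opcode, e.g. for STOREN, X is the first register and N the register
  // count, while for OP and CALL, X indexes the operator/function table. See
  // torch/csrc/jit/mobile/interpreter.cpp for the authoritative semantics.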
1112   std::vector<IValue> instructions{
1113       to_tuple({"STOREN", 1, 4}),
1114       to_tuple({"DROPR", 1, 0}),
1115       to_tuple({"MOVE", 4, 0}),
1116       to_tuple({"JF", 5, 0}),
1117       to_tuple({"LOAD", 2, 0}),
1118       to_tuple({"LOAD", 3, 0}),
1119       to_tuple({"LIST_CONSTRUCT", 0, 2}),
1120       to_tuple({"JMP", 4, 0}),
1121       to_tuple({"LOAD", 3, 0}),
1122       to_tuple({"LOAD", 2, 0}),
1123       to_tuple({"LIST_CONSTRUCT", 1, 2}),
1124       to_tuple({"STORE", 5, 0}),
1125       to_tuple({"DROPR", 3, 0}),
1126       to_tuple({"DROPR", 2, 0}),
1127       to_tuple({"MOVE", 5, 0}),
1128       to_tuple({"RET", 0, 0}),
1129   };
1130   std::vector<IValue> operators; // empty for this example
1131   std::vector<IValue> constants; // empty for this example
1132 
1133   std::vector<IValue> types{"List[int]", "List[int]"};
1134   // 2. Parse the function
1135   std::string function_name("test_function");
1136   auto function = std::unique_ptr<mobile::Function>(
1137       new mobile::Function(c10::QualifiedName(function_name)));
1138   c10::ivalue::TupleElements debug_handles_m_tuple;
1139   parseInstructions(
1140       function_name,
1141       std::move(*c10::ivalue::Tuple::create(instructions)).elements(),
1142       debug_handles_m_tuple,
1143       function.get());
1144   parseTypes(c10::ivalue::Tuple::create(types)->elements(), function.get());
1145   const size_t rsize = 5;
1146   parseRegisterSize(rsize, function.get());
1147 
1148   // 3. Prepare for inputs and run the function
1149   // Note that the first input is reserved for the Module object.
1150   // Since this is a function test and a Module object is not required,
1151   // a dummy IValue (0) is used in its place.
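  // Editorial note (assumption about the runtime behavior): Function::run()
  // uses `inputs` as the interpreter's value stack, so the returned value is
  // left on that stack after the call; that is why the result is read back
  // from inputs[0] below instead of from a separate return value.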
1152   std::vector<IValue> inputs{0, 1, 2, true};
1153   function->run(inputs);
1154   auto output = inputs[0].toList();
1155   ASSERT_EQ(output[0], 1);
1156   ASSERT_EQ(output[1], 2);
1157 
1158   std::vector<IValue> inputs1{0, 1, 2, false};
1159   function->run(inputs1);
1160   auto output1 = inputs1[0].toList();
1161   ASSERT_EQ(output1[0], 2);
1162   ASSERT_EQ(output1[1], 1);
1163 }
1164 
1165 TEST(RunTimeTest, ParseOperator) {
1166   // A simple example showing bytecode that can be used independently of
1167   // PyTorch TorchScript serialization (unpickler, etc.) and the operator library.
1168   // It has one operator and we should be able to register it. The original
1169   // PyTorch program:
1170 
1171   // class Add(torch.nn.Module):
1172   //     def __init__(self) -> None:
1173   //         super().__init__()
1174 
1175   //     def forward(self, a, b):
1176   //         return a + b
1177 
1178   // 1. Prepare the bytecode. In reality it can come from a customized
1179   // deserializer.
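  // Editorial note (assumption): each entry in the operator table below is a
  // tuple of (operator name, overload name, number of specified arguments);
  // ("aten::add", "Tensor", 2) resolves to aten::add.Tensor with two arguments
  // provided explicitly, the remaining ones (here, alpha) being filled in from
  // the operator schema's defaults.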
1180   std::vector<IValue> instructions{
1181       to_tuple({"STOREN", 1, 3}),
1182       to_tuple({"DROPR", 1, 0}),
1183       to_tuple({"MOVE", 2, 0}),
1184       to_tuple({"MOVE", 3, 0}),
1185       to_tuple({"OP", 0, 0}),
1186       to_tuple({"RET", 0, 0}),
1187   };
1188   std::vector<IValue> operators{
1189       to_tuple({"aten::add", "Tensor", 2}),
1190   };
1191   std::vector<IValue> constants{
1192       to_tuple({1}),
1193   };
1194   // 2. Parse the function
1195   std::string function_name("test_function");
1196   auto function = std::unique_ptr<mobile::Function>(
1197       new mobile::Function(c10::QualifiedName(function_name)));
1198   c10::ivalue::TupleElements debug_handles_m_tuple;
1199   parseInstructions(
1200       function_name,
1201       std::move(*c10::ivalue::Tuple::create(instructions)).elements(),
1202       debug_handles_m_tuple,
1203       function.get());
1204   parseOperators(
1205       std::move(*c10::ivalue::Tuple::create(operators)).elements(),
1206       1,
1207       function.get());
1208   const size_t rsize = 5;
1209   parseRegisterSize(rsize, function.get());
1210 
1211   // 3. Prepare for inputs and run the function
1212   // Note that the first input is reserved for the Module object.
1213   // Since this is a function test and a Module object is not required,
1214   // a dummy IValue (0) is used in its place.
1215   std::vector<IValue> inputs{0, at::tensor(1), at::tensor(2)};
1216   function->run(inputs);
1217   auto output = inputs[0];
1218   ASSERT_EQ(output, at::tensor(3));
1219 }
1220 
1221 namespace {
1222 void testLiteModuleCompareResultTensors(
1223     Module& m,
1224     const std::vector<torch::jit::IValue>& inputs,
1225     const std::string& method_name = "forward") {
1226   auto outputref = m.get_method(method_name)(inputs).toTensor();
1227 
1228   std::stringstream ss;
1229   m._save_for_mobile(ss);
1230   mobile::Module bc = _load_for_mobile(ss);
1231   IValue res;
1232   for (int i = 0; i < 3; ++i) {
1233     res = bc.get_method(method_name)(inputs);
1234   }
1235   auto output = res.toTensor();
1236   AT_ASSERT(outputref.dim() == output.dim());
1237   AT_ASSERT(output.equal(outputref));
1238 }
1239 
1240 void testDefaultArgsPinv(int num_args) {
1241   Module m("m");
1242   if (num_args == 1) {
1243     m.define(R"(
1244       def forward(self, input):
1245         return torch.linalg_pinv(input)
1246     )");
1247   } else if (num_args == 2) {
1248     m.define(R"(
1249       def forward(self, input):
1250         return torch.linalg_pinv(input, 1e-5)
1251     )");
1252   } else if (num_args == 3) {
1253     m.define(R"(
1254       def forward(self, input):
1255         return torch.linalg_pinv(input, 1e-5, True)
1256     )");
1257   }
1258 
1259   std::vector<torch::jit::IValue> inputs;
1260   const int N = 28;
1261   auto input = torch::range(1, N * N, 1);
1262   input[0] = 1; // a more stable matrix
1263   input = input.view({N, N});
1264   inputs.push_back(input);
1265   testLiteModuleCompareResultTensors(m, inputs);
1266 }
1267 } // namespace
1268 
1269 #if !defined FB_XPLAT_BUILD
1270 TEST(LiteInterpreterTest, DefaultArgsPinv) {
1271   // Test with different numbers of specified arguments.
1272   // Arguments not specified take their default values.
1273   for (int num_args = 1; num_args <= 3; ++num_args) {
1274     testDefaultArgsPinv(num_args);
1275   }
1276 
1277   //  bytecode with one specified argument:
1278   //  (6,
1279   //      ('__torch__.m.forward',
1280   //          (('instructions',
1281   //              (('STOREN', 1, 2),
1282   //                  ('DROPR', 1, 0),
1283   //                  ('MOVE', 2, 0),
1284   //                  ('OP', 0, 0),
1285   //                  ('RET', 0, 0))),
1286   //              ('operators', (('aten::linalg_pinv', '', 1),)),
1287   //              ('constants', (False, 1e-15)), # default constants are not
1288   //              used
1289   //              ('types', ()),
1290   //              ('register_size', 2)),
1291   //          (('arguments',
1292   //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
1293   //              None)),
1294   //                  (('name', 'input'), ('type', 'Tensor'), ('default_value',
1295   //                  None)))),
1296   //              ('returns',
1297   //                  ((('name', ''), ('type', 'Tensor'), ('default_value',
1298   //                  None)),)))))
1299 
1300   //  bytecode with 2 specified argument:
1301   //  (6,
1302   //      ('__torch__.m.forward',
1303   //          (('instructions',
1304   //              (('STOREN', 1, 2),
1305   //                  ('DROPR', 1, 0),
1306   //                  ('MOVE', 2, 0),
1307   //                  ('LOADC', 1, 0), # added LOADC for specified argument
1308   //                  ('OP', 0, 0),
1309   //                  ('RET', 0, 0))),
1310   //              ('operators', (('aten::linalg_pinv', '', 2),)),
1311   //              ('constants', (False, 1e-05)), # updated constant table
1312   //              ('types', ()),
1313   //              ('register_size', 2)),
1314   //          (('arguments',
1315   //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
1316   //              None)),
1317   //                  (('name', 'input'), ('type', 'Tensor'), ('default_value',
1318   //                  None)))),
1319   //              ('returns',
1320   //                  ((('name', ''), ('type', 'Tensor'), ('default_value',
1321   //                  None)),)))))
1322 
1323   //  bytecode with 3 specified arguments:
1324   //  (6,
1325   //      ('__torch__.m.forward',
1326   //          (('instructions',
1327   //              (('STOREN', 1, 2),
1328   //                  ('DROPR', 1, 0),
1329   //                  ('MOVE', 2, 0),
1330   //                  ('LOADC', 1, 0),
1331   //                  ('LOADC', 0, 0),
1332   //                  ('OP', 0, 0),
1333   //                  ('RET', 0, 0))),
1334   //              ('operators', (('aten::linalg_pinv', '', 3),)),
1335   //              ('constants', (True, 1e-05)),
1336   //              ('types', ()),
1337   //              ('register_size', 2)),
1338   //          (('arguments',
1339   //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
1340   //              None)),
1341   //                  (('name', 'input'), ('type', 'Tensor'), ('default_value',
1342   //                  None)))),
1343   //              ('returns',
1344   //                  ((('name', ''), ('type', 'Tensor'), ('default_value',
1345   //                  None)),)))))
1346 }
1347 
1348 TEST(LiteInterpreterTest, DefaultArgsTensorinvSpecifyDefault) {
1349   // The second argument is specified, but the value is the same as the default
1350   // value. It's treated as "not specified" since the value can be fetched from
1351   // the schema.
1352   Module m("m");
1353   m.define(R"(
1354     def forward(self, input):
1355       return torch.linalg_tensorinv(input, 2)
1356   )");
1357   torch::jit::MobileCode code(m.get_method("forward").graph(), "forward");
1358   auto arg_nums = code.op_to_num_specified_args();
1359   ASSERT_EQ(arg_nums.size(), 1);
1360   ASSERT_EQ(arg_nums["aten::linalg_tensorinv"], 1);
1361   std::vector<torch::jit::IValue> inputs;
1362   const int N = 4;
1363   auto input = torch::rand({N, N, N, N});
1364   inputs.push_back(input);
1365   testLiteModuleCompareResultTensors(m, inputs);
1366 }
1367 
1368 void testDefaultArgsPinvWithOutArg(int num_args) {
1369   Module m("m");
1370   if (num_args == 1) {
1371     m.define(R"(
1372       def forward(self, input):
1373         return torch.linalg_pinv(input, out=input)
1374     )");
1375   } else if (num_args == 2) {
1376     m.define(R"(
1377       def forward(self, input):
1378         return torch.linalg_pinv(input, 1e-5, out=input)
1379     )");
1380   } else if (num_args == 3) {
1381     m.define(R"(
1382       def forward(self, input):
1383         return torch.linalg_pinv(input, 1e-5, True, out=input)
1384     )");
1385   }
1386 
1387   const int N = 28;
1388   auto input = torch::range(1, N * N, 1);
1389   input[0] = 10000; // a more stable matrix
1390   input = input.view({N, N});
1391   auto ref = m.run_method("forward", input);
1392   TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
1393   TORCH_CHECK(input.equal(ref.toTensor()));
1394 }
1395 
1396 TEST(LiteInterpreterTest, DefaultArgsPinvWithOutArg) {
1397   // Test with different numbers of specified arguments + out arg.
1398   // Arguments not specified take their default values.
1399   for (int num_args = 1; num_args <= 3; ++num_args) {
1400     testDefaultArgsPinvWithOutArg(num_args);
1401   }
1402 }
1403 
1404 TEST(LiteInterpreterTest, DefaultArgsWithOutArg) {
1405   Module m("m");
1406   m.define(R"(
1407     def forward(self, x, h):
1408       torch.add(x, h, out=x)
1409   )");
1410 
1411   std::vector<IValue> inputs;
1412   auto input_x = 2 * torch::ones({});
1413   auto input_h = torch::ones({});
1414   auto ref = m.run_method("forward", input_x, input_h);
1415 
1416   std::stringstream ss;
1417 
1418   m._save_for_mobile(ss, {}, true);
1419   mobile::Module bc = _load_for_mobile(ss);
1420   bc.run_method("forward", input_x, input_h);
1421   AT_ASSERT(input_x.equal(4 * torch::ones({})));
1422 
1423   auto ops = _get_model_ops_and_info(ss);
1424   auto op = ops.find("aten::add.out");
1425   TORCH_CHECK(
1426       op != ops.end() && op->second.num_schema_args.has_value() &&
1427       op->second.num_schema_args.value() == 3);
1428 }
1429 
1430 TEST(LiteInterpreterTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
1431   Module a("A");
1432   a.define(R"(
1433     def bar(self, x, y):
1434       return x + y
1435   )");
1436   Module b("B");
1437   b.register_module("A0", a);
1438   b.define(R"(
1439     def foo(self, x, y):
1440       return self.A0.bar(x, y) + 2
1441   )");
1442   Module c("C");
1443   c.register_module("B0", b);
1444   c.define(R"(
1445     def forward(self, x, y):
1446       return self.B0.foo(x, y) + 3
1447   )");
1448 
1449   std::vector<IValue> inputs;
1450   inputs.emplace_back(torch::rand({2, 4}));
1451   inputs.emplace_back(torch::rand({13, 9}));
1452 
1453   std::stringstream ss;
1454   c._save_for_mobile(ss, ExtraFilesMap(), true);
1455   auto lite_m = _load_for_mobile(ss);
1456   std::string error_pattern = R"(
1457   Module hierarchy:top(C)::<unknown>.B0(B)::foo.A0(A)::bar.aten::add
1458 Traceback of TorchScript (most recent call last):
1459   File "<string>", line 3, in <unknown>
1460 
1461     def forward(self, x, y):
1462       return self.B0.foo(x, y) + 3
1463              ~~~~~~~~~~~ <--- HERE
1464 
1465   File "<string>", line 3, in foo
1466 
1467     def foo(self, x, y):
1468       return self.A0.bar(x, y) + 2
1469              ~~~~~~~~~~~ <--- HERE
1470 
1471   File "<string>", line 3, in bar
1472 
1473     def bar(self, x, y):
1474       return x + y
1475              ~~~~~ <--- HERE
1476   )";
1477   ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern);
1478 }
1479 #endif // !defined(FB_XPLAT_BUILD)
1480 
1481 namespace {
1482 static auto reg =
1483     torch::class_<TorchBindLiteInterpreterTestStruct>(
1484         "_TorchScriptTesting",
1485         "_LiteInterpreterTest")
1486         .def(torch::init<>())
1487         .def("get", &TorchBindLiteInterpreterTestStruct::get)
1488         .def_pickle(
1489             // __getstate__
1490             [](const c10::intrusive_ptr<TorchBindLiteInterpreterTestStruct>&
1491                    self) -> int64_t { return 0; },
1492             // __setstate__
__anond2aa1aa70602(int64_t state) 1493             [](int64_t state) {
1494               return c10::make_intrusive<TorchBindLiteInterpreterTestStruct>();
1495             });
1496 
1497 } // namespace
1498 
TEST(LiteInterpreterTest,OperatorCacheDifferentiatesDefaultArgs)1499 TEST(LiteInterpreterTest, OperatorCacheDifferentiatesDefaultArgs) {
1500   // Create 3 methods:
1501   //
1502   // 1. forward() returns a tensor with dtype=torch.int64 (4)
1503   // 2. forward2() returns a tensor with dtype=torch.float32 (6)
1504   // 3. forward3() returns a tensor with dtype=torch.float32 but
1505   //    the dtype is inferred by the input tensor's dtype
1506   //
1507   // If caching works correctly, the results from the full-jit
1508   // module and the lite module will match. They will differ if
1509   // the operator cache incorrectly reuses an entry for an
1510   // operator call that specifies a different number of
1511   // arguments.
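  // forward/forward2 record aten::new_empty with an explicit dtype argument
  // while forward3 relies on the default, so the same operator is recorded
  // with different numbers of specified arguments; the operator cache
  // presumably has to key on that count in addition to the operator name.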
1512   Module m("m");
1513   m.define(R"(
1514     def forward(self):
1515       ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4)
1516       return ret1.fill_(25)
1517   )");
1518   m.define(R"(
1519     def forward2(self):
1520       ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6)
1521       return ret1.fill_(32.0)
1522   )");
1523   m.define(R"(
1524     def forward3(self):
1525       ret1 = torch.new_empty(torch.zeros(10), [10])
1526       return ret1.fill_(12.0)
1527   )");
1528 
1529   std::vector<torch::jit::IValue> inputs;
1530   testLiteModuleCompareResultTensors(m, inputs, "forward");
1531   testLiteModuleCompareResultTensors(m, inputs, "forward2");
1532   testLiteModuleCompareResultTensors(m, inputs, "forward3");
1533 }
1534 
TEST(RunTimeTest,RuntimeCall)1535 TEST(RunTimeTest, RuntimeCall) {
1536   //     def call(x):
1537   //         return x + x
1538   //
1539   //     def forward(a):
1540   //         x = a + call(a)
1541   //         y = a + call(x)
1542   //         return y
1543 
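  //     A rough reading of the hand-written bytecode below (mobile
  //     interpreter semantics, as used in this test):
  //       STORE reg  - pop the stack top into register `reg`
  //       LOAD  reg  - push a copy of register `reg`
  //       MOVE  reg  - push register `reg` and clear it
  //       LOADC idx  - push constant `idx` from the constants table
  //       OP    idx  - run operator `idx` from the operators table
  //       CALL  idx  - invoke function `idx` registered via append_function()
  //       RET        - return, leaving the result(s) on the stack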
1544   std::vector<IValue> instructionsCall{
1545       to_tuple({"STORE", 1, 0}),
1546       to_tuple({"LOAD", 1, 0}),
1547       to_tuple({"MOVE", 1, 0}),
1548       to_tuple({"LOADC", 0, 0}),
1549       to_tuple({"OP", 0, 0}),
1550       to_tuple({"RET", 0, 0}),
1551   };
1552   std::vector<IValue> instructionsFoo{
1553       to_tuple({"STORE", 1, 0}),
1554       to_tuple({"LOAD", 1, 0}),
1555       to_tuple({"LOAD", 1, 0}),
1556       to_tuple({"MOVE", 1, 0}),
1557       to_tuple({"CALL", 0, 0}),
1558       to_tuple({"LOADC", 0, 0}),
1559       to_tuple({"OP", 0, 0}),
1560       to_tuple({"CALL", 0, 0}),
1561       to_tuple({"LOADC", 0, 0}),
1562       to_tuple({"OP", 0, 0}),
1563       to_tuple({"RET", 0, 0}),
1564   };
1565   std::vector<IValue> operatorsFoo{
1566       to_tuple({"aten::add", "Tensor", 3}),
1567   };
1568   std::vector<IValue> constantsFoo{
1569       1,
1570   };
1571   std::vector<IValue> operatorsCall{
1572       to_tuple({"aten::add", "Tensor", 3}),
1573   };
1574   std::vector<IValue> constantsCall{
1575       1,
1576   };
1577 
1578   auto foo = std::make_unique<mobile::Function>(c10::QualifiedName("foo"));
1579   c10::ivalue::TupleElements debug_handles_m_tuple;
1580   parseInstructions(
1581       "foo",
1582       std::move(*c10::ivalue::Tuple::create(instructionsFoo)).elements(),
1583       debug_handles_m_tuple,
1584       foo.get());
1585   parseOperators(
1586       std::move(*c10::ivalue::Tuple::create(operatorsFoo)).elements(),
1587       1,
1588       foo.get());
1589   parseConstants(
1590       std::move(*c10::ivalue::Tuple::create(constantsFoo)).elements(),
1591       foo.get());
1592   const size_t rsize = 5;
1593   parseRegisterSize(rsize, foo.get());
1594 
1595   auto call = std::make_unique<mobile::Function>(c10::QualifiedName("call"));
1596   parseInstructions(
1597       "call",
1598       std::move(*c10::ivalue::Tuple::create(instructionsCall)).elements(),
1599       debug_handles_m_tuple,
1600       call.get());
1601   parseOperators(
1602       std::move(*c10::ivalue::Tuple::create(operatorsCall)).elements(),
1603       1,
1604       call.get());
1605   parseConstants(
1606       std::move(*c10::ivalue::Tuple::create(constantsCall)).elements(),
1607       call.get());
1608   parseRegisterSize(rsize, call.get());
1609 
1610   foo->append_function(*call);
1611 
1612   std::vector<IValue> inputs{at::tensor(1)};
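  // With a = 1: call(a) = a + a = 2, x = a + call(a) = 3, call(x) = 6, and
  // y = a + call(x) = 7. run() uses `inputs` as the stack, so the result
  // ends up in inputs[0].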
1613   foo->run(inputs);
1614   auto output = inputs[0];
1615   ASSERT_EQ(output, at::tensor(7));
1616 }
1617 
TEST(LiteInterpreterTest,OperatorSize1)1618 TEST(LiteInterpreterTest, OperatorSize1) {
1619   Module m("m");
1620   m.define(R"(
1621     def forward(self, input: Tensor, scale:float):
1622       return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
1623   )");
1624 
1625   std::stringstream ss;
1626   m._save_for_mobile(ss);
1627   mobile::Module bc = _load_for_mobile(ss);
1628   const auto& func = bc.get_method("forward").function();
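  // operator_input_sizes_ presumably records, per operator, how many arguments
  // the call site actually specified, which is what lets the runtime fill in
  // trailing defaults; it should therefore have one entry per operator.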
1629   ASSERT_EQ(
1630       func.get_code().operator_input_sizes_.size(),
1631       func.get_code().operators_.size());
1632 }
1633 
TEST(LiteInterpreterTest,OperatorTest2)1634 TEST(LiteInterpreterTest, OperatorTest2) { // NOLINT (use =delete in gtest)
1635   const std::vector<std::string> test_programs{
1636       // test invoking a method with default parameter
1637       R"(
1638       def test_func(self, x, b : int = 4):
1639         return self.foo + x + b
1640       )",
1641       // inner method call with default parameter (gets inlined)
1642       R"(
1643       def add_with_default_arg(self, x, b : int = 4):
1644         return self.foo + x + b
1645       def test_func(self, x):
1646         return self.add_with_default_arg(x)  # invoke method w/ default arg
1647       )",
1648       // simple method call
1649       R"(
1650       def test_func(self, x):
1651         b = 4
1652         return self.foo + x + b
1653       )",
1654   };
1655   for (const auto& test_program : test_programs) {
1656     Module m("m");
1657     m.register_parameter("foo", torch::ones({}), false);
1658     m.define(test_program);
1659 
1660     std::stringstream ss;
1661     m._save_for_mobile(ss);
1662     mobile::Module bc = _load_for_mobile(ss);
1663     const auto& func = bc.get_method("test_func").function();
1664     ASSERT_EQ(
1665         func.get_code().operator_input_sizes_.size(),
1666         func.get_code().operators_.size());
1667   }
1668 }
1669 
1670 #if !defined FB_XPLAT_BUILD
1671 // The following tests run in fbcode only
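// Background for the upgrader tests below: these .ptl files were exported when
// aten::div on two integer tensors still truncated toward zero. When such an
// old model is loaded on a newer runtime, the loader substitutes an upgrader
// function for the old operator, which is why CALL instructions appear in the
// bytecode. The generated upgrader is roughly equivalent to (a sketch, not the
// exact generated code):
//
//   def div_Tensor_0_3(self: Tensor, other: Tensor) -> Tensor:
//     if self.is_floating_point() or other.is_floating_point():
//       return self.true_divide(other)
//     return self.divide(other, rounding_mode='trunc')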
TEST(LiteInterpreterUpgraderTest,DivTensorV2)1672 TEST(LiteInterpreterUpgraderTest, DivTensorV2) {
1673   std::string filePath(__FILE__);
1674   auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
1675   test_model_file.append("upgrader_models/test_versioned_div_tensor_v2.ptl");
1676   /*
1677   (('__torch__.MyModule.forward',
1678     (('instructions',
1679       (('STOREN', 1, 3),
1680        ('DROPR', 1, 0),
1681        ('LOAD', 2, 0),
1682        ('LOAD', 3, 0),
1683        ('OP', 0, 0),
1684        ('LOAD', 2, 0),
1685        ('LOAD', 3, 0),
1686        ('OP', 1, 0),
1687        ('MOVE', 2, 0),
1688        ('MOVE', 3, 0),
1689        ('OP', 2, 0),
1690        ('TUPLE_CONSTRUCT', 3, 0),
1691        ('RET', 0, 0))),
1692      ('operators',
1693       (('aten::div', 'Tensor'),
1694        ('aten::div', 'Tensor'),
1695        ('aten::div', 'Tensor'))),
1696      ('constants', ()),
1697      ('types', ()),
1698      ('register_size', 3))),)
1699 
1700   */
1701   mobile::Module m_module = _load_for_mobile(test_model_file);
1702   auto instruction_list =
1703       m_module.get_method("forward").function().get_code().instructions_;
1704   uint64_t number_of_call_instruction = 0;
1705   for (auto& instruction : instruction_list) {
1706     number_of_call_instruction += (instruction.op == OpCode::CALL);
1707   }
1708   // 3 operators will use upgrader
1709   ASSERT_EQ(number_of_call_instruction, 3);
1710 
1711   std::vector<IValue> inputs = {
1712       IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))};
1713   auto actual_output = m_module.forward(inputs);
1714   auto expect_output = 2.0 * torch::ones({1});
1715   auto actual_output_list = actual_output.toTuple()->elements();
1716   ASSERT_TRUE(actual_output_list[0].toTensor().equal(expect_output));
1717 }
1718 
TEST(LiteInterpreterUpgraderTest,DivTensorOutV2)1719 TEST(LiteInterpreterUpgraderTest, DivTensorOutV2) {
1720   std::string filePath(__FILE__);
1721   auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
1722   test_model_file.append(
1723       "upgrader_models/test_versioned_div_tensor_out_v2.ptl");
1724   /*
1725   (('__torch__.MyModule.forward',
1726     (('instructions',
1727       (('STOREN', 1, 4),
1728        ('DROPR', 1, 0),
1729        ('MOVE', 2, 0),
1730        ('MOVE', 3, 0),
1731        ('MOVE', 4, 0),
1732        ('OP', 0, 0),
1733        ('RET', 0, 0))),
1734      ('operators', (('aten::div', 'out'),)),
1735      ('constants', ()),
1736      ('types', ()),
1737      ('register_size', 4))),)
1738   */
1739   mobile::Module m_module = _load_for_mobile(test_model_file);
1740 
1741   auto instruction_list =
1742       m_module.get_method("forward").function().get_code().instructions_;
1743   uint64_t number_of_call_instruction = 0;
1744   for (auto& instruction : instruction_list) {
1745     number_of_call_instruction += (instruction.op == OpCode::CALL);
1746   }
1747   // One operator will use upgrader
1748   ASSERT_EQ(number_of_call_instruction, 1);
1749 
1750   std::vector<IValue> inputs{
1751       IValue(6 * torch::ones({1})),
1752       IValue(3 * torch::ones({1})),
1753       IValue(torch::empty({1}))};
1754   m_module.forward(inputs);
1755   auto expect_output = 2.0 * torch::ones({1});
1756   auto actual_output = inputs[2].toTensor();
1757   // The out argument will be overwritten with the output
1758   ASSERT_TRUE(actual_output.equal(expect_output));
1759 }
1760 
TEST(LiteInterpreterUpgraderTest,DivTensorInplaceV2)1761 TEST(LiteInterpreterUpgraderTest, DivTensorInplaceV2) {
1762   std::string filePath(__FILE__);
1763   auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
1764   test_model_file.append(
1765       "upgrader_models/test_versioned_div_tensor_inplace_v2.ptl");
1766   /*
1767   (('__torch__.MyModule.forward',
1768     (('instructions',
1769       (('STOREN', 1, 3),
1770        ('DROPR', 1, 0),
1771        ('MOVE', 2, 0),
1772        ('MOVE', 3, 0),
1773        ('OP', 0, 0),
1774        ('RET', 0, 0))),
1775      ('operators', (('aten::div_', 'Tensor'),)),
1776      ('constants', ()),
1777      ('types', ()),
1778      ('register_size', 3))),)
1779   */
1780   mobile::Module m_module = _load_for_mobile(test_model_file);
1781 
1782   auto instruction_list =
1783       m_module.get_method("forward").function().get_code().instructions_;
1784   uint64_t number_of_call_instruction = 0;
1785   for (auto& instruction : instruction_list) {
1786     number_of_call_instruction += (instruction.op == OpCode::CALL);
1787   }
1788   // One operator will use upgrader
1789   ASSERT_EQ(number_of_call_instruction, 1);
1790 
1791   std::vector<IValue> inputs{
1792       IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))};
1793   m_module.forward(inputs);
1794   auto expect_output = 2.0 * torch::ones({1});
1795   auto actual_output = inputs[0].toTensor();
1796   // aten::div_ modifies the self argument in place, so inputs[0] holds the result
1797   ASSERT_TRUE(actual_output.equal(expect_output));
1798 }
1799 
TEST(LiteInterpreterUpgraderTest,DivScalarFloatV2)1800 TEST(LiteInterpreterUpgraderTest, DivScalarFloatV2) {
1801   std::string filePath(__FILE__);
1802   auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
1803   test_model_file.append(
1804       "upgrader_models/test_versioned_div_scalar_float_v2.ptl");
1805   /*
1806   (('__torch__.MyModuleFloat.forward',
1807     (('instructions',
1808     (('STOREN', 1, 3),
1809     ('DROPR', 1, 0),
1810     ('MOVE', 2, 0),
1811     ('MOVE', 3, 0),
1812     ('OP', 0, 0),
1813     ('RET', 0, 0))),
1814     ('operators', (('aten::div', 'Scalar'),)),
1815     ('constants', ()),
1816     ('types', ()),
1817     ('register_size', 3))),)
1818   */
1819 
1820   mobile::Module m_module = _load_for_mobile(test_model_file);
1821 
1822   auto instruction_list =
1823       m_module.get_method("forward").function().get_code().instructions_;
1824   uint64_t number_of_call_instruction = 0;
1825   for (auto& instruction : instruction_list) {
1826     number_of_call_instruction += (instruction.op == OpCode::CALL);
1827   }
1828   // One operator will use upgrader
1829   ASSERT_EQ(number_of_call_instruction, 1);
1830 
1831   std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
1832   auto output = m_module.forward(inputs);
1833   auto expect_output = 2.0 * torch::ones({1});
1834   auto actual_output = output.toTensor();
1835 
1836   // 6 / 3.0 == 2.0 (true division)
1837   ASSERT_TRUE(actual_output.equal(expect_output));
1838 }
1839 
TEST(LiteInterpreterUpgraderTest,DivScalarReciprocalFloatV2)1840 TEST(LiteInterpreterUpgraderTest, DivScalarReciprocalFloatV2) {
1841   std::string filePath(__FILE__);
1842   auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
1843   test_model_file.append(
1844       "upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl");
1845   /*
1846   (('__torch__.MyModuleFloat.forward',
1847     (('instructions',
1848       (('STOREN', 1, 3),
1849       ('DROPR', 1, 0),
1850       ('MOVE', 2, 0),
1851       ('OP', 0, 0),
1852       ('MOVE', 3, 0),
1853       ('OP', 1, 0),
1854       ('RET', 0, 0))),
1855     ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))),
1856     ('constants', ()),
1857     ('types', ()),
1858     ('register_size', 3))),)
1859   */
1860   mobile::Module m_module = _load_for_mobile(test_model_file);
1861 
1862   auto instruction_list =
1863       m_module.get_method("forward").function().get_code().instructions_;
1864   uint64_t number_of_call_instruction = 0;
1865   for (auto& instruction : instruction_list) {
1866     number_of_call_instruction += (instruction.op == OpCode::CALL);
1867   }
1868   // No operator will use upgrader
1869   ASSERT_EQ(number_of_call_instruction, 0);
1870 
1871   std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
1872   auto output = m_module.forward(inputs);
1873   auto expect_output = 0.5 * torch::ones({1});
1874   auto actual_output = output.toTensor();
1875   std::cout << "expect output: " << expect_output;
1876   std::cout << "actual output: " << actual_output;
1877   // reciprocal(6) * 3.0 == 0.5
1878   ASSERT_TRUE(actual_output.equal(expect_output));
1879 }
1880 
TEST(LiteInterpreterUpgraderTest,DivScalarReciprocalIntV2)1881 TEST(LiteInterpreterUpgraderTest, DivScalarReciprocalIntV2) {
1882   std::string filePath(__FILE__);
1883   auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
1884   test_model_file.append(
1885       "upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl");
1886   /*
1887   (('__torch__.MyModuleInt.forward',
1888   (('instructions',
1889     (('STOREN', 1, 3),
1890      ('DROPR', 1, 0),
1891      ('MOVE', 2, 0),
1892      ('OP', 0, 0),
1893      ('MOVE', 3, 0),
1894      ('OP', 1, 0),
1895      ('RET', 0, 0))),
1896    ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))),
1897    ('constants', ()),
1898    ('types', ()),
1899    ('register_size', 3))),)
1900   */
1901   mobile::Module m_module = _load_for_mobile(test_model_file);
1902 
1903   auto instruction_list =
1904       m_module.get_method("forward").function().get_code().instructions_;
1905   uint64_t number_of_call_instruction = 0;
1906   for (auto& instruction : instruction_list) {
1907     number_of_call_instruction += (instruction.op == OpCode::CALL);
1908   }
1909   // No operator will use upgrader
1910   ASSERT_EQ(number_of_call_instruction, 0);
1911 
1912   std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
1913   auto output = m_module.forward(inputs);
1914   auto expect_output = 0.5 * torch::ones({1});
1915   auto actual_output = output.toTensor();
1916 
1917   // reciprocal(6) * 3.0 == 0.5
1918   ASSERT_TRUE(actual_output.equal(expect_output));
1919 }
1920 
TEST(LiteInterpreterUpgraderTest,DivScalarScalarV2)1921 TEST(LiteInterpreterUpgraderTest, DivScalarScalarV2) {
1922   std::string filePath(__FILE__);
1923   auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
1924   test_model_file.append(
1925       "upgrader_models/test_versioned_div_scalar_scalar_v2.ptl");
1926   /*
1927   (('__torch__.MyModule.forward',
1928     (('instructions',
1929       (('STOREN', 1, 5),
1930       ('DROPR', 1, 0),
1931       ('LOAD', 2, 0),
1932       ('LOAD', 3, 0),
1933       ('OP', 0, 0),
1934       ('MOVE', 2, 0),
1935       ('LOAD', 4, 0),
1936       ('OP', 1, 0),
1937       ('LOAD', 3, 0),
1938       ('MOVE', 4, 0),
1939       ('OP', 2, 0),
1940       ('MOVE', 3, 0),
1941       ('MOVE', 5, 0),
1942       ('OP', 3, 0),
1943       ('TUPLE_CONSTRUCT', 4, 0),
1944       ('RET', 0, 0))),
1945     ('operators',
1946       (('aten::div', ''),
1947       ('aten::div', 'float'),
1948       ('aten::div', ''),
1949       ('aten::div', 'int'))),
1950     ('constants', ()),
1951     ('types', ()),
1952     ('register_size', 5))),)
1953   */
1954   mobile::Module m_module = _load_for_mobile(test_model_file);
1955   auto instruction_list =
1956       m_module.get_method("forward").function().get_code().instructions_;
1957   uint64_t number_of_call_instruction = 0;
1958   for (auto& instruction : instruction_list) {
1959     number_of_call_instruction += (instruction.op == OpCode::CALL);
1960   }
1961   // No operator will use upgrader
1962   ASSERT_EQ(number_of_call_instruction, 0);
1963 
1964   std::vector<IValue> inputs{IValue(20.0), IValue(10), IValue(2.0), IValue(5)};
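  // With (a, b, c, d) = (20.0, 10, 2.0, 5), the bytecode above appears to
  // compute (a / b, a / c, b / c, b / d) = (2.0, 10.0, 5.0, 2.0), matching
  // expect_output below; note that even the int/int case (10 / 5) yields a
  // float under true division.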
1965   auto output = m_module.forward(inputs);
1966   auto output_list = output.toTupleRef().elements();
1967   auto expect_output = std::vector<IValue>(
1968       {IValue(2.0), IValue(10.0), IValue(5.0), IValue(2.0)});
1969   // auto actual_output = output.toTensor();
1970   for (size_t i = 0; i < expect_output.size(); i++) {
1971     ASSERT_EQ(output_list[i], expect_output[i]);
1972   }
1973 }
1974 
TEST(LiteInterpreterUpgraderTest,DivScalarIntV2)1975 TEST(LiteInterpreterUpgraderTest, DivScalarIntV2) {
1976   std::string filePath(__FILE__);
1977   auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
1978   test_model_file.append(
1979       "upgrader_models/test_versioned_div_scalar_int_v2.ptl");
1980   /*
1981   (('__torch__.MyModuleInt.forward',
1982     (('instructions',
1983       (('STOREN', 1, 3),
1984       ('DROPR', 1, 0),
1985       ('MOVE', 2, 0),
1986       ('MOVE', 3, 0),
1987       ('OP', 0, 0),
1988       ('RET', 0, 0))),
1989     ('operators', (('aten::div', 'Scalar'),)),
1990     ('constants', ()),
1991     ('types', ()),
1992     ('register_size', 3))),)
1993   */
1994   mobile::Module m_module = _load_for_mobile(test_model_file);
1995 
1996   auto instruction_list =
1997       m_module.get_method("forward").function().get_code().instructions_;
1998   uint64_t number_of_call_instruction = 0;
1999   for (auto& instruction : instruction_list) {
2000     number_of_call_instruction += (instruction.op == OpCode::CALL);
2001   }
2002   // One operator will use upgrader
2003   ASSERT_EQ(number_of_call_instruction, 1);
2004 
2005   std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3)};
2006   auto output = m_module.forward(inputs);
2007   auto expect_output = 2.0 * torch::ones({1});
2008   auto actual_output = output.toTensor();
2009 
2010   // 6 / 3 == 2.0 (true division)
2011   ASSERT_TRUE(actual_output.equal(expect_output));
2012 }
2013 
TEST(LiteInterpreterUpgraderTest,DivScalarInplaceFloatV2)2014 TEST(LiteInterpreterUpgraderTest, DivScalarInplaceFloatV2) {
2015   std::string filePath(__FILE__);
2016   auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
2017   test_model_file.append(
2018       "upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl");
2019   /*
2020   (('__torch__.MyModuleFloat.forward',
2021     (('instructions',
2022       (('STOREN', 1, 3),
2023       ('DROPR', 1, 0),
2024       ('MOVE', 2, 0),
2025       ('MOVE', 3, 0),
2026       ('OP', 0, 0),
2027       ('RET', 0, 0))),
2028     ('operators', (('aten::div_', 'Scalar'),)),
2029     ('constants', ()),
2030     ('types', ()),
2031     ('register_size', 3))),)
2032   */
2033 
2034   mobile::Module m_module = _load_for_mobile(test_model_file);
2035 
2036   auto instruction_list =
2037       m_module.get_method("forward").function().get_code().instructions_;
2038   uint64_t number_of_call_instruction = 0;
2039   for (auto& instruction : instruction_list) {
2040     number_of_call_instruction += (instruction.op == OpCode::CALL);
2041   }
2042   // One operator will use upgrader
2043   ASSERT_EQ(number_of_call_instruction, 1);
2044 
2045   std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
2046   auto output = m_module.forward(inputs);
2047   auto expect_output = 2.0 * torch::ones({1});
2048   auto actual_output = output.toTensor();
2049 
2050   // div_ divides self in place and returns it: 6 / 3.0 == 2.0
2051   ASSERT_TRUE(actual_output.equal(expect_output));
2052 }
2053 
TEST(LiteInterpreterUpgraderTest,DivScalarInplaceIntV2)2054 TEST(LiteInterpreterUpgraderTest, DivScalarInplaceIntV2) {
2055   std::string filePath(__FILE__);
2056   auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
2057   test_model_file.append(
2058       "upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl");
2059   /*
2060   (('__torch__.MyModuleInt.forward',
2061     (('instructions',
2062       (('STOREN', 1, 3),
2063        ('DROPR', 1, 0),
2064        ('MOVE', 2, 0),
2065        ('MOVE', 3, 0),
2066        ('OP', 0, 0),
2067        ('RET', 0, 0))),
2068      ('operators', (('aten::div_', 'Scalar'),)),
2069      ('constants', ()),
2070      ('types', ()),
2071      ('register_size', 3))),)
2072   */
2073 
2074   mobile::Module m_module = _load_for_mobile(test_model_file);
2075 
2076   auto instruction_list =
2077       m_module.get_method("forward").function().get_code().instructions_;
2078   uint64_t number_of_call_instruction = 0;
2079   for (auto& instruction : instruction_list) {
2080     number_of_call_instruction += (instruction.op == OpCode::CALL);
2081   }
2082   // One operator will use upgrader
2083   ASSERT_EQ(number_of_call_instruction, 1);
2084 
2085   std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3)};
2086   auto output = m_module.forward(inputs);
2087   auto expect_output = 2.0 * torch::ones({1});
2088   auto actual_output = output.toTensor();
2089 
2090   // div_ divides self in place and returns it: 6 / 3 == 2.0
2091   ASSERT_TRUE(actual_output.equal(expect_output));
2092 }
2093 
2094 #endif // !defined(FB_XPLAT_BUILD)
2095 
TEST(LiteInterpreterUpgraderTest,Upgrader)2096 TEST(LiteInterpreterUpgraderTest, Upgrader) {
2097   std::vector<mobile::Function> upgrader_functions;
2098 
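  // Each entry from getUpgraderBytecodeList() bundles an upgrader's bytecode
  // with the operators it needs. The loop below checks that operator metadata
  // stays consistent after initialize_operators(), and appends the operators
  // manually when the function was not initialized with them.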
2099   for (auto& byteCodeFunctionWithOperator : getUpgraderBytecodeList()) {
2100     byteCodeFunctionWithOperator.function.initialize_operators(true);
2101     ASSERT_EQ(
2102         byteCodeFunctionWithOperator.function.get_code().operators_.size(),
2103         byteCodeFunctionWithOperator.function.get_code().op_names_.size());
2104     if (byteCodeFunctionWithOperator.function.get_code().operators_.empty()) {
2105       for (const auto& op : byteCodeFunctionWithOperator.operators) {
2106         byteCodeFunctionWithOperator.function.append_operator(
2107             op.name, op.overload_name, op.num_specified_args);
2108       }
2109     }
2110     upgrader_functions.push_back(byteCodeFunctionWithOperator.function);
2111   }
2112 
2113   ASSERT_EQ(getUpgraderBytecodeList().size(), upgrader_functions.size());
2114 }
2115 
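// Recursively enumerates every tuple type of `depth` elements drawn from
// `candidates` (skipping any type that contains AnyType), emitting both the
// plain TupleType and a NamedTuple variant with synthesized field names.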
enumerateTupleType(size_t depth,std::vector<TypePtr> & current,const std::vector<TypePtr> & candidates,std::vector<TypePtr> & out)2116 void enumerateTupleType(
2117     size_t depth,
2118     std::vector<TypePtr>& current,
2119     const std::vector<TypePtr>& candidates,
2120     std::vector<TypePtr>& out) {
2121   static std::vector<std::string> fieldNames;
2122   if (depth > fieldNames.size()) {
2123     fieldNames.reserve(depth);
2124     for (size_t i = fieldNames.size(); i < depth; i++) {
2125       fieldNames.push_back("field" + std::to_string(i));
2126     }
2127   }
2128   if (depth == 0) {
2129     out.push_back(TupleType::create(current));
2130     while (fieldNames.size() > current.size()) {
2131       fieldNames.pop_back();
2132     }
2133     out.push_back(TupleType::createNamed("NamedTuple", fieldNames, current));
2134     return;
2135   }
2136   for (const auto& type : candidates) {
2137     if (containsAnyType(type)) {
2138       continue;
2139     }
2140     current.push_back(type);
2141     enumerateTupleType(depth - 1, current, candidates, out);
2142     current.pop_back();
2143   }
2144 }
2145 
2146 class LiteInterpreterDynamicTypeTestFixture
2147     : public ::testing::TestWithParam<size_t> {
2148  protected:
SetUp()2149   void SetUp() override {
2150     cu = std::make_shared<CompilationUnit>();
2151     std::vector<TypePtr> keyTypes = {
2152         AnyType::get(),
2153         IntType::get(),
2154         BoolType::get(),
2155         FloatType::get(),
2156         ComplexType::get(),
2157         StringType::get(),
2158         TensorType::get(),
2159         DeviceObjType::get(),
2160     };
2161     types = {
2162         NoneType::get(),
2163         NumberType::get(),
2164         ClassType::create("__torch__.TestClass1", cu),
2165         ClassType::create("__torch__.TestClass2", cu),
2166         AnyListType::get(),
2167         AnyTupleType::get(),
2168         StreamObjType::get(),
2169         CapsuleType::get(),
2170         GeneratorType::get(),
2171         StorageType::get(),
2172         VarType::create("t"),
2173         VarType::create("v"),
2174         AnyClassType::get()};
2175     std::copy(keyTypes.begin(), keyTypes.end(), back_inserter(types));
2176     auto expandTypes = [&](size_t tupleSize) {
2177       std::vector<TypePtr> nested;
2178       for (const auto& type : types) {
2179         if (!(type == AnyType::get())) {
2180           nested.emplace_back(ListType::create(type));
2181           if (!(type == NoneType::get() ||
2182                 type->kind() == OptionalType::Kind)) {
2183             nested.emplace_back(OptionalType::create(type));
2184           }
2185         }
2186         for (const auto& keyType : keyTypes) {
2187           nested.emplace_back(DictType::create(keyType, type));
2188         }
2189       }
2190       std::vector<TypePtr> tmp;
2191       enumerateTupleType(tupleSize, tmp, types, nested);
2192       std::move(
2193           std::begin(nested), std::end(nested), std::back_inserter(types));
2194     };
2195     expandTypes(1);
2196     expandTypes(1);
2197   }
2198   std::shared_ptr<CompilationUnit> cu;
2199   std::vector<TypePtr> types;
2200 
2201  public:
2202   static constexpr size_t kNumSplits = 10;
2203 };
2204 
2205 /**
2206  * Enumerate all possible JIT types appearing in mobile runtime, and test
2207  * whether subtyping relation is preserved after one of the JIT types is
2208  * converted to DynamicType.
2209  *
2210  * We first enumerate all "base" types in a vector, and implement
2211  * expandTypes() to enumerate container types one "level" up for a given set
2212  * of types. We call expandTypes() twice to cover types nested up to two
2213  * levels, e.g. List[Optional[Tensor]], Optional[Dict[Int, Bool]], etc.
2214  */
TEST_P(LiteInterpreterDynamicTypeTestFixture,Conformance)2215 TEST_P(LiteInterpreterDynamicTypeTestFixture, Conformance) {
2216   size_t num = types.size() / LiteInterpreterDynamicTypeTestFixture::kNumSplits;
2217   size_t begin = num * GetParam();
2218   size_t end = std::min(types.size(), begin + num);
2219   for (const auto& a : types) {
2220     auto da = DynamicType::create(*a);
2221     for (size_t i = begin; i < end; i++) {
2222       const auto& b = types[i];
2223       bool result = a->isSubtypeOf(*b);
2224       EXPECT_EQ(result, da->isSubtypeOf(*b));
2225       result = b->isSubtypeOf(*a);
2226       EXPECT_EQ(result, b->isSubtypeOf(*da));
2227     }
2228   }
2229 }
2230 
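// Enumerating all pairs of types is quadratic, so the suite is split into
// kNumSplits shards: each parameter value checks every type against one slice
// of the type list, keeping each instantiation's runtime bounded.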
2231 INSTANTIATE_TEST_SUITE_P(
2232     PyTorch,
2233     LiteInterpreterDynamicTypeTestFixture,
2234     ::testing::Range(
2235         static_cast<size_t>(0),
2236         LiteInterpreterDynamicTypeTestFixture::kNumSplits));
2237 
2238 } // namespace jit
2239 } // namespace torch
2240