1 #include <test/cpp/jit/test_utils.h>
2
3 #include <c10/core/TensorOptions.h>
4 #include <gtest/gtest.h>
5 #include <torch/csrc/autograd/generated/variable_factories.h>
6 #include <torch/csrc/jit/api/module.h>
7 #include <torch/csrc/jit/frontend/resolver.h>
8 #include <torch/csrc/jit/mobile/compatibility/backport.h>
9 #include <torch/csrc/jit/mobile/compatibility/backport_manager.h>
10 #include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
11 #include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
12 #include <torch/csrc/jit/mobile/import.h>
13 #include <torch/csrc/jit/mobile/interpreter.h>
14 #include <torch/csrc/jit/mobile/module.h>
15 #include <torch/csrc/jit/mobile/parse_bytecode.h>
16 #include <torch/csrc/jit/mobile/parse_operators.h>
17 #include <torch/csrc/jit/mobile/upgrader_mobile.h>
18 #include <torch/csrc/jit/serialization/export.h>
19 #include <torch/csrc/jit/serialization/import.h>
20 #include <torch/custom_class.h>
21 #include <torch/torch.h>
22
23 #include <torch/csrc/jit/serialization/import_export_functions.h>
24 #include <unordered_set>
25
26 // Tests go in torch::jit
27 namespace torch {
28 namespace jit {
29
// Checks that upsample_nearest2d produces identical tensors when executed by
// the full JIT module and by the lite interpreter.
TEST(LiteInterpreterTest, UpsampleNearest2d) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));

  // Reference result from the full JIT module.
  const auto expected = m.forward(inputs).toTensor();

  // Round-trip through the mobile format and run the very same inputs.
  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module mobile_module = _load_for_mobile(model_stream);
  const auto actual = mobile_module.forward(inputs).toTensor();

  ASSERT_TRUE(actual.equal(expected));
}
52
// Verifies that a boolean module attribute survives the mobile save/load
// round-trip for both of its possible values.
TEST(LiteInterpreterTest, CheckAttrAccess) {
  Module m("m");
  m.register_attribute("mobile_optimized", BoolType::get(), true);

  // Serializes `m` into a fresh stream, loads it with the lite interpreter and
  // returns the value of the "mobile_optimized" attribute seen by mobile.
  const auto read_flag_via_mobile = [&m]() {
    std::stringstream model_stream;
    m._save_for_mobile(model_stream);
    mobile::Module loaded = _load_for_mobile(model_stream);
    return loaded.attr("mobile_optimized", false).toBool();
  };

  AT_ASSERT(read_flag_via_mobile());

  m.setattr("mobile_optimized", false);
  AT_ASSERT(!read_flag_via_mobile());
}
71
// Runs three script variants of `test_func` through the lite interpreter and
// compares each result with the one produced by the full JIT module.
TEST(LiteInterpreterTest, MethodInvocation) { // NOLINT (use =delete in gtest)
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
      def test_func(self, x, b : int = 4):
        return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
      def add_with_default_arg(self, x, b : int = 4):
        return self.foo + x + b
      def test_func(self, x):
        return self.add_with_default_arg(x)  # invoke method w/ default arg
      )",
      // simple method call
      R"(
      def test_func(self, x):
        b = 4
        return self.foo + x + b
      )",
  };
  for (const auto& program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(program);

    const int fortyTwo = 42; // (keep linter happy)
    auto minput = fortyTwo * torch::ones({});
    const auto expected =
        m.run_method("test_func", minput).toTensor().item<float>();

    std::stringstream model_stream;
    m._save_for_mobile(model_stream);
    mobile::Module bc = _load_for_mobile(model_stream);
    const auto& test_func = bc.get_method("test_func");

    // The mobile method is invoked three times; only the last result is
    // compared against the reference.
    IValue res;
    for (int run = 0; run != 3; ++run) {
      res = test_func({minput});
    }

    const auto actual = res.toTensor().item<float>();
    AT_ASSERT(actual == expected);
  }
}
116
// Runs a convolution through the lite interpreter and compares rank and the
// first output element against the full JIT result.
TEST(LiteInterpreterTest, Conv) {
  // The test is a no-op when PYTORCH_TEST_WITH_TSAN is set to exactly "1".
  const char* tsan_flag = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (tsan_flag != nullptr && std::string(tsan_flag) == "1") {
    return;
  }

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
  )");

  std::vector<torch::jit::IValue> inputs;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));

  const auto expected = m.forward(inputs).toTensor();

  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);

  // Invoke the mobile method three times; the last result is the one checked.
  IValue res;
  for (int run = 0; run != 3; ++run) {
    res = bc.get_method("forward")(inputs);
  }
  const auto actual = res.toTensor();

  AT_ASSERT(expected.dim() == actual.dim());
  AT_ASSERT(
      expected[0][0][0][0].item<int>() == actual[0][0][0][0].item<int>());
}
149
// Checks a chain of nested method calls (foo3 -> foo2 -> foo1) when run by the
// lite interpreter.
TEST(LiteInterpreterTest, Inline) {
  Module m("m");
  m.define(R"JIT(
  def foo1(self, x):
      return x + 1

  def foo2(self, x):
      return self.foo1(x) + 2

  def foo3(self, x):
      return self.foo2(x) + 3
  )JIT");

  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);

  const std::vector<torch::jit::IValue> args({torch::ones({})});
  const auto result = bc.get_method("foo3")(args);

  // Input is 1, so the chain yields 1 + 1 + 2 + 3.
  AT_ASSERT(result.toTensor().item<float>() == 7.0);
}
169
// Checks that a tuple returned from a scripted method can be read back from
// the lite interpreter's result.
TEST(LiteInterpreterTest, Tuple) {
  Module m("m");
  m.define(R"JIT(
  def foo(self, x):
      return (1, 2, x + 3)

  def forward(self, x):
      tuple = self.foo(x)
      return tuple
  )JIT");

  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);

  const std::vector<torch::jit::IValue> args({torch::ones({})});
  const auto result = bc.get_method("forward")(args);

  // The second tuple element is the literal 2 from the script.
  const auto& elements = result.toTupleRef().elements();
  AT_ASSERT(elements[1].toInt() == 2);
}
187
// Compares the string produced by str.format in the lite interpreter with the
// one produced by the full JIT module, using the script's default arguments.
TEST(LiteInterpreterTest, AtenFormat) {
  Module m("m");
  m.define(R"""(
  def forward(self, fmt:str="first {} {}", num:str="abc"):
    x = 2
    x = x * x
    return fmt.format(num, x)
  )""");

  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);

  // No explicit inputs: both runs rely on the defaults declared in the script.
  const std::vector<torch::jit::IValue> no_args;
  const auto mobile_result = bc.get_method("forward")(no_args);
  const auto jit_result = m.get_method("forward")(no_args);

  AT_ASSERT(jit_result.toStringRef() == mobile_result.toStringRef());
}
206
// Checks that reading `x.device` gives the same device string in the lite
// interpreter as in the full JIT module.
TEST(LiteInterpreterTest, PrimDevice) {
  Module m("m");
  m.define(R"""(
  def forward(self, x:torch.Tensor):
    return x.device
  )""");

  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);

  std::vector<torch::jit::IValue> inputs;
  inputs.emplace_back(3.5 * torch::ones({}));

  const auto mobile_result = bc.get_method("forward")(inputs);
  const auto jit_result = m.get_method("forward")(inputs);

  AT_ASSERT(mobile_result.toDevice().str() == jit_result.toDevice().str());
}
223
// Checks that a dict returned from a scripted method can be read back from
// the lite interpreter's result.
TEST(LiteInterpreterTest, Dict) {
  Module m("m");
  m.define(R"JIT(
  def foo(self, x):
      return {"result": x + 1}

  def forward(self, x):
      d = self.foo(x)
      return d
  )JIT");

  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);

  const std::vector<torch::jit::IValue> args({torch::ones({})});
  const auto result = bc.get_method("forward")(args);

  // Input is 1, so the "result" entry must hold 1 + 1.
  const auto entry = result.toGenericDict().at("result");
  AT_ASSERT(entry.toTensor().item().toInt() == 2);
}
241
// Checks that a list returned from a scripted method matches between the lite
// interpreter and the full JIT module.
TEST(LiteInterpreterTest, List) {
  Module m("m");
  m.define(R"JIT(
  def foo(self, x):
      return [x + 2]

  def forward(self, x):
      d = self.foo(x)
      return d
  )JIT");

  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);

  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto mobile_output = bc.get_method("forward")(inputs);
  auto server_output = m.forward(inputs);

  // Input is 1, so the single list element must hold 1 + 2.
  EXPECT_EQ(mobile_output.toList().get(0).toTensor().item().toInt(), 3);
  EXPECT_EQ(mobile_output, server_output);
}
261
// Placeholder test: the body is entirely commented out, so the test currently
// executes nothing. The disabled code is kept below for reference.
TEST(LiteInterpreterTest, PrimOverload) {
  /*
  // temporarily disabled
  script::Module m("m");
  m.define(R"JIT(
  def forward(self, x):
      result = [1, 2]
      result.append(3)
      return result
  )JIT");
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("forward")(inputs);
  AT_ASSERT(output.toIntList()[2] == 3);
  */
}
280
// Checks that `int(x)` on a tensor gives the same integer in the lite
// interpreter as in the full JIT module.
TEST(LiteInterpreterTest, Prim) {
  Module m("m");
  m.define(R"JIT(
        def forward(self, x):
            return int(x)
  )JIT");

  auto minput = 3.5 * torch::ones({});
  std::vector<IValue> inputs;
  inputs.emplace_back(minput);
  const auto expected = m.run_method("forward", minput).toInt();

  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);

  // Each call receives its own copy of `inputs` (the method takes its
  // arguments by value); only the last result is compared.
  IValue res;
  for (int run = 0; run != 3; ++run) {
    res = bc.get_method("forward")(inputs);
  }

  AT_ASSERT(res.toInt() == expected);
}
307
// Checks that `int(x.item())` gives the same integer in the lite interpreter
// as in the full JIT module.
TEST(LiteInterpreterTest, PrimScalar) {
  Module m("m");
  m.define(R"JIT(
        def forward(self, x):
            return int(x.item())
  )JIT");

  auto minput = 3.5 * torch::ones({});
  std::vector<IValue> inputs;
  inputs.emplace_back(minput);
  const auto expected = m.run_method("forward", minput).toInt();

  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);

  // Each call receives its own copy of `inputs` (the method takes its
  // arguments by value); only the last result is compared.
  IValue res;
  for (int run = 0; run != 3; ++run) {
    res = bc.get_method("forward")(inputs);
  }

  AT_ASSERT(res.toInt() == expected);
}
334
// A module written with the regular `save()` (not `_save_for_mobile`) must be
// rejected by the mobile loader with a "file not found" message.
TEST(LiteInterpreterTest, LoadOrigJit) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::stringstream jit_stream;
  m.save(jit_stream);

  ASSERT_THROWS_WITH_MESSAGE(_load_for_mobile(jit_stream), "file not found");
}
347
// Requesting a method the module never defined ("forward") must throw with an
// "is not defined" message.
TEST(LiteInterpreterTest, WrongMethodName) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);

  std::vector<IValue> inputs;
  inputs.emplace_back(5 * torch::ones({}));

  ASSERT_THROWS_WITH_MESSAGE(
      bc.get_method("forward")(inputs), "is not defined");
}
365
// Checks that a module with custom __getstate__/__setstate__ yields the same
// forward result after a mobile round-trip as after a regular save/load.
TEST(LiteInterpreterTest, SetState) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo + self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  auto minput = 5 * torch::ones({});
  std::vector<IValue> inputs;
  inputs.emplace_back(minput);

  // Reference: regular JIT save/load, then run forward on the loaded module.
  std::stringstream jit_stream;
  m.save(jit_stream);
  auto loaded_m = load(jit_stream);
  const auto expected =
      loaded_m.run_method("forward", minput).toTensor().item<float>();

  // Mobile round-trip of the same module.
  std::stringstream mobile_stream;
  m._save_for_mobile(mobile_stream);
  mobile::Module bc = _load_for_mobile(mobile_stream);

  // Each call receives its own copy of `inputs` (the method takes its
  // arguments by value); only the last result is compared.
  IValue res;
  for (int run = 0; run != 3; ++run) {
    res = bc.get_method("forward")(inputs);
  }

  const auto actual = res.toTensor().item<float>();
  AT_ASSERT(actual == expected);
}
402
403 class TorchBindLiteInterpreterTestStruct
404 : public torch::jit::CustomClassHolder {
405 public:
get(at::Tensor t)406 std::string get(at::Tensor t) {
407 std::stringstream ss;
408 ss << "Hello! Your tensor has ";
409 ss << t.numel();
410 ss << " elements!";
411 return ss.str();
412 }
413 };
414
415 namespace {
416 struct ClassNamespaceValue : public SugaredValue {
ClassNamespaceValuetorch::jit::__anond2aa1aa70111::ClassNamespaceValue417 explicit ClassNamespaceValue(c10::QualifiedName name)
418 : basename_(std::move(name)) {}
419
attrtorch::jit::__anond2aa1aa70111::ClassNamespaceValue420 std::shared_ptr<SugaredValue> attr(
421 const SourceRange& loc,
422 GraphFunction& m,
423 const std::string& name) override {
424 const auto fullName = c10::QualifiedName(basename_, name);
425
426 // Check to see if it is a custom class.
427 if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
428 return std::make_shared<ClassValue>(custom_class);
429 }
430
431 // If it's not a custom class, assume it's another namespace
432 // NOLINTNEXTLINE(performance-move-const-arg)
433 return std::make_shared<ClassNamespaceValue>(std::move(fullName));
434 }
435
kindtorch::jit::__anond2aa1aa70111::ClassNamespaceValue436 std::string kind() const override {
437 return "Class Namespace";
438 }
439
440 private:
441 c10::QualifiedName basename_;
442 };
443
444 struct TestModuleResolver : public Resolver {
resolveValuetorch::jit::__anond2aa1aa70111::TestModuleResolver445 std::shared_ptr<SugaredValue> resolveValue(
446 const std::string& name,
447 GraphFunction& m,
448 const SourceRange& loc) override {
449 if (name == "torch") {
450 return std::make_shared<BuiltinModule>("aten");
451 } else if (name == "__torch__") {
452 return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
453 }
454
455 return nullptr;
456 }
457
resolveTypetorch::jit::__anond2aa1aa70111::TestModuleResolver458 TypePtr resolveType(const std::string& name, const SourceRange& loc)
459 override {
460 return nullptr;
461 }
462 };
463 } // namespace
464
// Checks that a module holding a registered custom class attribute, recreated
// in __setstate__, can be called through the lite interpreter.
TEST(LiteInterpreterTest, BuiltinClass) {
  script::Module m("m");

  auto cls = getCustomClass(
      "__torch__.torch.classes._TorchScriptTesting._LiteInterpreterTest");
  TORCH_INTERNAL_ASSERT(cls);

  // The attribute is registered with a default-constructed (empty) holder; the
  // script's __setstate__ assigns a real instance.
  c10::intrusive_ptr<torch::CustomClassHolder> obj_holder;
  m.register_attribute("my_obj", cls, IValue::make_capsule(obj_holder));

  m.register_parameter("foo", torch::ones({}), false);
  m.define(
      R"(
    def __getstate__(self):
      return 1
    def __setstate__(self, a):
      self.my_obj = __torch__.torch.classes._TorchScriptTesting._LiteInterpreterTest()

    def forward(self, x) -> str:
      return self.my_obj.get(x)
  )",
      std::make_shared<TestModuleResolver>());

  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);

  const auto result =
      bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});

  // A 3x4 input has 12 elements.
  const std::string expected = "Hello! Your tensor has 12 elements!";
  AT_ASSERT(result.toStringRef() == expected);
}
496
// Checks that a method of a custom class instance stored as a module attribute
// can be called through the lite interpreter.
TEST(LiteInterpreterTest, BuiltinFunction) {
  script::Module m("m");
  auto holder = make_custom_class<TorchBindLiteInterpreterTestStruct>();
  m.register_attribute("my_obj", holder.type(), holder);
  m.define(R"(
    def forward(self, x) -> str:
      return self.my_obj.get(x)
  )");

  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);

  const auto result =
      bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});

  // A 3x4 input has 12 elements.
  const std::string expected = "Hello! Your tensor has 12 elements!";
  AT_ASSERT(result.toStringRef() == expected);
}
517
518 #if !defined FB_XPLAT_BUILD
// The runtime must report the maximum supported bytecode version as its own
// bytecode version.
TEST(LiteInterpreterTest, GetRuntimeByteCodeVersion) {
  const auto reported_version = _get_runtime_bytecode_version();
  const auto expected_version =
      caffe2::serialize::kMaxSupportedBytecodeVersion;
  AT_ASSERT(reported_version == expected_version);
}
525
// The runtime's (min, max) operator versions must equal the supported
// file-format version range.
TEST(LiteInterpreterTest, GetRuntimeOperatorsVersion) {
  const auto versions = _get_runtime_operators_min_max_versions();
  const bool min_matches =
      versions.first == caffe2::serialize::kMinSupportedFileFormatVersion;
  const bool max_matches =
      versions.second == caffe2::serialize::kMaxSupportedFileFormatVersion;
  AT_ASSERT(min_matches && max_matches);
}
534
535 /**
536 * The test below is disarmed for FB internal xplat builds since
537 * BUCK requires us to pass in the script_module_v4.ptl file in
538 * as a resource dependency of the build rule for this file, and
539 * we would need to access it via the C++ Resources API instead
540 * of directly reading from disk (which is what the open source
541 * build/run does).
542 */
// Reads the bytecode version of script_module_v4.ptl, expected to sit next to
// this source file, and requires it to be 4.
TEST(LiteInterpreterTest, GetByteCodeVersion) {
  const std::string this_file(__FILE__);
  // Keep everything up to and including the last path separator.
  const std::string directory =
      this_file.substr(0, this_file.find_last_of("/\\") + 1);
  const std::string model_path = directory + "script_module_v4.ptl";

  AT_ASSERT(_get_model_bytecode_version(model_path) == 4);
}
552
553 #endif // !defined(FB_XPLAT_BUILD)
554
// Smoke test: querying the contained types of a saved mobile model must not
// throw. The returned value itself is not inspected.
TEST(LiteInterpreterTest, GetContainTypes) {
  Module m("m");
  m.define(R"(
    def forward(self):
      return 3
  )");

  std::stringstream model_stream;
  m._save_for_mobile(model_stream, {}, true);

  _get_mobile_model_contained_types(model_stream);
}
567
568 namespace {
569
compareModelOutput(c10::ArrayRef<IValue> actual_result_list,const std::vector<IValue> & expect_result_list)570 void compareModelOutput(
571 c10::ArrayRef<IValue> actual_result_list,
572 const std::vector<IValue>& expect_result_list) {
573 AT_ASSERT(actual_result_list.size() == expect_result_list.size());
574 AT_ASSERT(
575 actual_result_list[0].toTensor().equal(expect_result_list[0].toTensor()));
576 AT_ASSERT(
577 actual_result_list[1].toTensor().dim() ==
578 expect_result_list[1].toTensor().dim());
579 AT_ASSERT(
580 actual_result_list[2].toTensor().equal(expect_result_list[2].toTensor()));
581 AT_ASSERT(
582 actual_result_list[3].toTensor().equal(expect_result_list[3].toTensor()));
583 ASSERT_EQ(
584 actual_result_list[4].toStringRef(), expect_result_list[4].toStringRef());
585 ASSERT_EQ(actual_result_list[5].toBool(), expect_result_list[5].toBool());
586 ASSERT_EQ(actual_result_list[6].toBool(), expect_result_list[6].toBool());
587 ASSERT_EQ(actual_result_list[7].toBool(), expect_result_list[7].toBool());
588 AT_ASSERT(
589 actual_result_list[8].toTensor().equal(expect_result_list[8].toTensor()));
590 ASSERT_EQ(
591 actual_result_list[9].toStringRef(), expect_result_list[9].toStringRef());
592 ASSERT_EQ(actual_result_list[10].toInt(), expect_result_list[10].toInt());
593 ASSERT_EQ(actual_result_list[11].toBool(), expect_result_list[11].toBool());
594 }
595
runAndCheckTorchScriptModel(std::stringstream & input_model_stream,const std::vector<IValue> & input_data,const std::vector<IValue> & expect_result_list,const uint64_t expect_version)596 void runAndCheckTorchScriptModel(
597 std::stringstream& input_model_stream,
598 const std::vector<IValue>& input_data,
599 const std::vector<IValue>& expect_result_list,
600 const uint64_t expect_version) {
601 auto actual_version = _get_model_bytecode_version(input_model_stream);
602 AT_ASSERT(actual_version == expect_version);
603
604 // Load and run the backport model, then compare the result with expect
605 // result
606 Module m_mobile = load(input_model_stream);
607
608 auto actual_result = m_mobile.forward(input_data);
609 const auto& actual_result_list = actual_result.toTupleRef().elements();
610 compareModelOutput(actual_result_list, expect_result_list);
611 }
612
runAndCheckBytecodeModel(std::stringstream & input_model_stream,const std::vector<IValue> & input_data,const std::vector<IValue> & expect_result_list,const uint64_t expect_version)613 void runAndCheckBytecodeModel(
614 std::stringstream& input_model_stream,
615 const std::vector<IValue>& input_data,
616 const std::vector<IValue>& expect_result_list,
617 const uint64_t expect_version) {
618 auto actual_version = _get_model_bytecode_version(input_model_stream);
619 AT_ASSERT(actual_version == expect_version);
620
621 // Load and run the backport model, then compare the result with expect
622 // result
623 Module m_mobile = load(input_model_stream);
624
625 auto actual_result = m_mobile.forward(input_data);
626 const auto& actual_result_list = actual_result.toTupleRef().elements();
627
628 compareModelOutput(actual_result_list, expect_result_list);
629 }
630
backportAllVersionCheck(std::stringstream & test_model_file_stream,std::vector<IValue> & input_data,std::vector<IValue> & expect_result_list,const uint64_t expect_from_version)631 void backportAllVersionCheck(
632 std::stringstream& test_model_file_stream,
633 std::vector<IValue>& input_data,
634 std::vector<IValue>& expect_result_list,
635 const uint64_t expect_from_version) {
636 auto from_version = _get_model_bytecode_version(test_model_file_stream);
637 EXPECT_EQ(from_version, expect_from_version);
638 AT_ASSERT(from_version > 0);
639
640 // Backport script_module_v5.ptl to an older version
641 constexpr int64_t minimum_to_version = 4;
642 auto current_to_version = from_version - 1;
643
644 // Verify all candidate to_version work as expected. All backport to version
645 // larger than minimum_to_version should success.
646 while (current_to_version >= minimum_to_version) {
647 // Do not declare std::stringstream oss outside of the while loop as
648 // oss.clear() doesn't reset the stream content, only clears out error state
649 // flag in stringstream causing a problematic stream. Instead, it's cleaner
650 // and safer to just declare a new std::stringstream one and swap them.
651 std::stringstream oss;
652 bool backPortSuccess =
653 _backport_for_mobile(test_model_file_stream, oss, current_to_version);
654 AT_ASSERT(backPortSuccess);
655
656 // Check backport model version
657 auto backport_version = _get_model_bytecode_version(oss);
658 backport_version = _get_model_bytecode_version(oss);
659 AT_ASSERT(backport_version == current_to_version);
660
661 // Load and run the backport model, then compare the result with expect
662 // result
663 runAndCheckBytecodeModel(
664 oss, input_data, expect_result_list, current_to_version);
665 oss.seekg(0, oss.beg);
666 runAndCheckTorchScriptModel(
667 oss, input_data, expect_result_list, current_to_version);
668
669 current_to_version--;
670 }
671 // backport to minimum version - 1 should fail
672 std::stringstream oss;
673 bool backPortSuccess =
674 _backport_for_mobile(test_model_file_stream, oss, minimum_to_version - 1);
675 AT_ASSERT(!backPortSuccess);
676 }
677 } // namespace
678
679 #if !defined FB_XPLAT_BUILD
// Builds a frozen module covering a range of operators, saves it for mobile in
// flatbuffer format and checks that it can be backported to every older
// supported bytecode version with unchanged results.
TEST(LiteInterpreterTest, BackPortByteCodeModelAllVersions) {
  torch::jit::Module module("m");
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  module.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  module.register_parameter("bias", torch::ones({20}), false);
  module.define(R"(
    def fn(self, x:float=1.0):
      return x

    def forward(self, input):
      x1 = torch.zeros(2, 2)
      x2 = torch.empty_like(torch.empty(2, 2))
      x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
      # Add torch.add operator to cover bytecode version bump from 6 to 7
      # for bytecode version 7, the main change is to support defaults arguments with out arguments
      x = 2 * torch.ones(1)
      h = torch.ones(1)
      torch.add(x, h, out=x)
      device = torch.ones(1, 1).cpu().device.type
      is_cuda = x1.is_cuda
      bool_val = True
      check_is = [] is None
      check_is_not = [1] is not None
      check_not = not bool_val
      num_to_tensor = torch.tensor([self.fn()])
      d = {"a": "abc"}
      check_dict_index = d["a"]
      check_dim = x1.dim()
      return (
        x1, x2, x3, x, device, is_cuda, check_is,
        check_is_not, num_to_tensor, check_dict_index,
        check_dim, check_not
        )
      )");

  torch::jit::Module frozen_module = freeze(module);

  std::stringstream model_stream;
  frozen_module._save_for_mobile(
      model_stream,
      /*extra_files=*/{},
      /*save_mobile_debug_info=*/false,
      /*use_flatbuffer=*/true);

  std::vector<IValue> input_data({torch::ones({1, 1, 28, 28})});

  // Expected outputs, in the order of the tuple returned by forward().
  std::vector<IValue> expected;
  expected.emplace_back(at::ones({2, 2}, ScalarType::Float) * 0); // x1
  expected.emplace_back(at::ones({2, 2}, ScalarType::Float)); // x2
  expected.emplace_back(
      at::ones({1, 20, 24, 24}, ScalarType::Float) * 26); // x3
  expected.emplace_back(3 * at::ones({1})); // x
  // Remaining entries:
  // "cpu", False, False, True, tensor(1), "abc", 2, False
  expected.emplace_back(c10::IValue("cpu")); // device
  expected.emplace_back(c10::IValue(false)); // is_cuda
  expected.emplace_back(c10::IValue(false)); // check_is
  expected.emplace_back(c10::IValue(true)); // check_is_not
  expected.emplace_back(c10::IValue(at::ones({1}))); // num_to_tensor
  expected.emplace_back(c10::IValue("abc")); // check_dict_index
  expected.emplace_back(c10::IValue(2)); // check_dim
  expected.emplace_back(c10::IValue(false)); // check_not

  backportAllVersionCheck(
      model_stream,
      input_data,
      expected,
      9); // flatbuffer starts at 9
}
748 #endif // !defined(FB_XPLAT_BUILD)
749
// Sanity check that the runtime reports a plausibly large operator set.
TEST(LiteInterpreterTest, GetRuntimeOpsAndInfo) {
  // Ballpark lower bound on the number of ops; it only guards against the API
  // returning an implausibly small set.
  constexpr size_t kMinExpectedOps = 2900;
  const auto runtime_ops = _get_runtime_ops_and_info();
  AT_ASSERT(runtime_ops.size() > kMinExpectedOps);
}
756
// Trivial success case: a model using one operator and three type names,
// checked against the actual runtime, must be reported as compatible.
TEST(LiteInterpreterTest, isCompatibleSuccess) {
  std::unordered_map<std::string, OperatorInfo> model_ops;
  model_ops["aten::add.Scalar"] = OperatorInfo{2};
  const std::unordered_set<std::string> model_types = {
      "List", "int", "NamedTuple"};

  const auto model_info = ModelCompatibilityInfo{
      caffe2::serialize::kMaxSupportedBytecodeVersion,
      model_ops,
      model_types,
      _get_runtime_bytecode_min_max_versions().first};
  const auto runtime_info = RuntimeCompatibilityInfo::get();

  const auto result = is_compatible(runtime_info, model_info);
  AT_ASSERT(result.status == ModelCompatibilityStatus::OK);
}
774
// Exercises the failure paths of is_compatible(): missing operator, bytecode
// version above/below the supported range, unsupported type, and operator
// version mismatch. Every case must report ModelCompatibilityStatus::ERROR.
//
// The status checks use `==`. They previously used `=`, which assigned ERROR
// to result.status instead of comparing it, so the status returned by
// is_compatible() was never actually verified.
TEST(LiteInterpreterTest, isCompatibleFail) {
  // test trivial failure due to ops
  std::unordered_map<std::string, OperatorInfo> model_ops;
  model_ops["aten::add.Scalar"] = OperatorInfo{2};
  auto model_info = ModelCompatibilityInfo{
      caffe2::serialize::kMaxSupportedBytecodeVersion, model_ops};
  std::unordered_map<std::string, OperatorInfo> runtime_ops;
  runtime_ops["aten::add.Int"] = OperatorInfo{2};
  auto runtime_info = RuntimeCompatibilityInfo{
      std::pair<uint64_t, uint64_t>(
          caffe2::serialize::kMinSupportedBytecodeVersion,
          caffe2::serialize::kMaxSupportedBytecodeVersion),
      runtime_ops,
      _get_mobile_supported_types()};

  auto result = is_compatible(runtime_info, model_info);
  AT_ASSERT(result.status == ModelCompatibilityStatus::ERROR);
  AT_ASSERT(
      result.errors[0] ==
      "Operator 'aten::add.Scalar' missing from runtime (not found)");

  // test trivial failure due to bytecode greater than max supported bytecode
  // version
  runtime_ops["aten::add.Scalar"] = OperatorInfo{2};
  runtime_info = RuntimeCompatibilityInfo{
      std::pair<uint64_t, uint64_t>(
          caffe2::serialize::kMinSupportedBytecodeVersion,
          caffe2::serialize::kMaxSupportedBytecodeVersion),
      runtime_ops,
      _get_mobile_supported_types()};
  model_info.bytecode_version =
      caffe2::serialize::kMaxSupportedBytecodeVersion + 1;

  result = is_compatible(runtime_info, model_info);
  AT_ASSERT(result.status == ModelCompatibilityStatus::ERROR);

  // test trivial failure due to bytecode less than min supported bytecode
  // version
  runtime_ops["aten::add.Scalar"] = OperatorInfo{2};
  runtime_info = RuntimeCompatibilityInfo{
      std::pair<uint64_t, uint64_t>(
          caffe2::serialize::kMinSupportedBytecodeVersion,
          caffe2::serialize::kMaxSupportedBytecodeVersion),
      runtime_ops,
      _get_mobile_supported_types()};
  model_info.bytecode_version =
      caffe2::serialize::kMinSupportedBytecodeVersion - 1;

  result = is_compatible(runtime_info, model_info);
  AT_ASSERT(result.status == ModelCompatibilityStatus::ERROR);

  // test trivial failure due to type
  runtime_info = RuntimeCompatibilityInfo::get();
  std::unordered_set<std::string> types = {"List", "int", "Sequence"};

  model_info = ModelCompatibilityInfo{
      caffe2::serialize::kMaxSupportedBytecodeVersion,
      model_ops,
      types,
      _get_runtime_bytecode_min_max_versions().first};

  AT_ASSERT(
      is_compatible(runtime_info, model_info).status ==
      ModelCompatibilityStatus::ERROR);

  // test trivial failure due to operator version
  runtime_info = RuntimeCompatibilityInfo::get();

  model_info = ModelCompatibilityInfo{
      caffe2::serialize::kMaxSupportedBytecodeVersion, model_ops, {}, 0};

  AT_ASSERT(
      is_compatible(runtime_info, model_info).status ==
      ModelCompatibilityStatus::ERROR);
}
850
// Checks that calling eval() on a mobile module that was saved in training
// mode gives the same dropout output as the full JIT module in eval mode.
TEST(LiteInterpreterTest, Eval) {
  Module m("m");
  m.define(R"(
    def __init__(self, x):
      self.training = True

    def forward(self, input):
      return torch.dropout(input, 1.0, self.training)
  )");

  std::vector<torch::jit::IValue> inputs;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));

  // Reference output, computed in eval mode.
  m.eval();
  const auto expected = m.forward(inputs).toTensor();

  // save m in training mode to make sure that mobile eval() will correctly
  // change back to eval mode
  m.train();
  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);
  bc.eval();

  // Invoke the mobile method three times; the last result is the one checked.
  IValue res;
  for (int run = 0; run != 3; ++run) {
    res = bc.get_method("forward")(inputs);
  }
  const auto actual = res.toTensor();

  AT_ASSERT(expected.dim() == actual.dim());
  AT_ASSERT(
      expected[0][0][0][0].item<int>() == actual[0][0][0][0].item<int>());
}
884
// find_method must return an empty optional for a method the module never
// defined ("forward").
TEST(LiteInterpreterTest, FindWrongMethodName) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);

  const auto missing = bc.find_method("forward");
  ASSERT_TRUE(missing == std::nullopt);
}
898
// Looks up an existing method with find_method and checks that running it in
// the lite interpreter matches the full JIT result.
TEST(LiteInterpreterTest, FindAndRunMethod) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_it(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(5 * torch::ones({}));
  const auto expected =
      m.get_method("add_it")(inputs).toTensor().item<float>();

  std::stringstream model_stream;
  m._save_for_mobile(model_stream);
  mobile::Module bc = _load_for_mobile(model_stream);

  // Look the method up afresh on every run and call it with its own copy of
  // the inputs; only the last result is compared.
  IValue res;
  for (int run = 0; run != 3; ++run) {
    auto method = bc.find_method("add_it");
    AT_ASSERT(method != std::nullopt);
    res = (*method)(inputs);
  }

  const auto actual = res.toTensor().item<float>();
  AT_ASSERT(actual == expected);
}
928
// Exercises the variadic run_method() overload on both the full JIT module
// and the mobile module and checks that they agree.
TEST(LiteInterpreterTest, RunMethodVariadic) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_three(self, x, y):
      return self.foo + x + y
  )");

  // Arguments are passed directly to run_method(); no IValue vector needed.
  auto inputx = 5 * torch::ones({});
  auto inputy = 4 * torch::ones({});
  auto ref = m.run_method("add_three", inputx, inputy);

  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  IValue res = bc.run_method("add_three", inputx, inputy);

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  // gtest assertion instead of AT_ASSERT so both values are printed on
  // failure.
  ASSERT_EQ(resd, refd);
}
951
// Serializes a module that defines __getstate__/__setstate__ and checks the
// number of methods the loaded mobile module reports.
TEST(LiteInterpreterTest, DuplicateSetState) {
  Module leaf("M");
  leaf.register_parameter("foo", torch::ones({}), false);
  leaf.define(R"(
    def __getstate__(self):
      return self.foo + self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  // The same submodule is registered twice under two different names.
  Module parent("B");
  parent.register_module("M0", leaf);
  parent.register_module("M1", leaf);
  parent.define(R"(
    def forward(self, x):
      return self.M0.forward(x) + self.M1.forward(x)
  )");

  // NOTE(review): the leaf module "M" -- not the parent "B" built above -- is
  // what gets serialized here; confirm that this is intentional.
  std::stringstream buffer;
  leaf._save_for_mobile(buffer);
  mobile::Module lite = _load_for_mobile(buffer);

  // "M" defines three methods above.
  const size_t expected_n = 3;
  const auto loaded_methods = lite.get_methods();
  ASSERT_EQ(loaded_methods.size(), expected_n);
}
980
// Round-trips "extra files" through _save_for_mobile/_load_for_mobile using
// three ways of requesting them: one explicit key, keys discovered from the
// archive records, and PARSE_ALL_EXTRA_FILE_MAPS.
TEST(LiteInterpreterTest, ExtraFiles) {
  const auto script = R"JIT(
    def forward(self):
        x = torch.rand(5, 5)
        x = x.mm(x)
        return x
  )JIT";

  auto module =
      std::make_shared<Module>("Module", std::make_shared<CompilationUnit>());
  module->define(script);

  // Save the module together with two extra files.
  std::unordered_map<std::string, std::string> saved_files;
  saved_files["metadata.json"] = "abc";
  saved_files["mobile_info.json"] = "{\"key\": 23}";
  std::ostringstream oss;
  module->_save_for_mobile(oss, saved_files);

  // 1. Request a single file by pre-populating its key.
  std::istringstream iss(oss.str());
  std::unordered_map<std::string, std::string> requested_files;
  requested_files["metadata.json"] = "";
  _load_for_mobile(iss, torch::kCPU, requested_files);
  ASSERT_EQ(requested_files["metadata.json"], "abc");

  // 2. Discover every record stored under "extra/" and request all of them.
  requested_files.clear();
  const std::vector<std::string> records =
      caffe2::serialize::PyTorchStreamReader(&iss).getAllRecords();
  const std::string extra_prefix = "extra/";
  for (const auto& record : records) {
    const bool is_extra =
        record.compare(0, extra_prefix.size(), extra_prefix) == 0;
    if (is_extra) {
      requested_files[record.substr(extra_prefix.size())] = "";
    }
  }
  iss.seekg(0, iss.beg);
  _load_for_mobile(iss, torch::kCPU, requested_files);
  ASSERT_EQ(requested_files["metadata.json"], "abc");
  ASSERT_EQ(requested_files["mobile_info.json"], "{\"key\": 23}");

  // 3. Start from an empty map and let the loader fill in every extra file.
  std::unordered_map<std::string, std::string> auto_discovered_files;
  iss.seekg(0, iss.beg);
  _load_for_mobile(
      iss,
      torch::kCPU,
      auto_discovered_files,
      MobileModuleLoadOptions::PARSE_ALL_EXTRA_FILE_MAPS);
  ASSERT_EQ(auto_discovered_files["metadata.json"], "abc");
  ASSERT_EQ(auto_discovered_files["mobile_info.json"], "{\"key\": 23}");
}
1032
// The operator list exported from a loaded mobile module must match the set
// of operators expected for the scripted forward().
TEST(LiteInterpreterTest, OpNameExportFetchRootOperators) {
  Module scripted("m");
  scripted.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  scripted.register_parameter("bias", torch::ones({20}), false);
  scripted.define(R"(
    def forward(self, input):
      x1 = torch.zeros(2, 2)
      x2 = torch.empty_like(torch.empty(2, 2))
      x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
      return (x1, x2, x3)
  )");
  scripted.eval();

  std::stringstream buffer;
  scripted._save_for_mobile(buffer);
  mobile::Module lite = _load_for_mobile(buffer);

  const std::set<std::string> expected_operator_names = {
      "aten::_convolution",
      "aten::empty.memory_format",
      "aten::empty_like",
      "aten::zeros",
  };
  const std::set<std::string> operator_names =
      mobile::_export_operator_list(lite);
  EXPECT_EQ(operator_names, expected_operator_names)
      << "Expected the root operator lists to be the same";
}
1061
// Runs conv2d with only the leading arguments specified and checks that the
// mobile module produces the same tensor as the full JIT module.
TEST(LiteInterpreterTest, DefaultArgsConv) {
  // The test returns early when PYTORCH_TEST_WITH_TSAN is set to "1".
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && strcmp(s, "1") == 0) {
    return;
  }

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1)
  )");

  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);

  // A single invocation; the former one-iteration loop added nothing.
  IValue res = bc.get_method("forward")(inputs);
  auto output = res.toTensor();

  // gtest assertions (instead of AT_ASSERT) give readable failure output.
  ASSERT_EQ(outputref.dim(), output.dim());
  ASSERT_TRUE(output.equal(outputref));
}
1092
TEST(RunTimeTest, ParseBytecode) {
  // A simple example to show a simple bytecode that can be used independent of
  // PyTorch TorchScript serialization (unpickler, etc) and operator library.
  // It has basic control flow (if, else) and basic data orchestration (list
  // construction). The original PyTorch program:

  //  class Module(torch.nn.Module):
  //
  //    def __init__(self) -> None:
  //      super().__init__()
  //
  //    def forward(self, x: int, h: int, xfirst: bool):
  //      if xfirst:
  //        return [x, h]
  //      else:
  //        return [h, x]

  // 1. Prepare for the bytecode. In reality it can be from a customized
  // deserializer.
  std::vector<IValue> instructions{
      to_tuple({"STOREN", 1, 4}),
      to_tuple({"DROPR", 1, 0}),
      to_tuple({"MOVE", 4, 0}),
      to_tuple({"JF", 5, 0}),
      to_tuple({"LOAD", 2, 0}),
      to_tuple({"LOAD", 3, 0}),
      to_tuple({"LIST_CONSTRUCT", 0, 2}),
      to_tuple({"JMP", 4, 0}),
      to_tuple({"LOAD", 3, 0}),
      to_tuple({"LOAD", 2, 0}),
      to_tuple({"LIST_CONSTRUCT", 1, 2}),
      to_tuple({"STORE", 5, 0}),
      to_tuple({"DROPR", 3, 0}),
      to_tuple({"DROPR", 2, 0}),
      to_tuple({"MOVE", 5, 0}),
      to_tuple({"RET", 0, 0}),
  };
  std::vector<IValue> operators; // empty for this example
  std::vector<IValue> constants; // empty for this example

  std::vector<IValue> types{"List[int]", "List[int]"};
  // 2. Parse the function
  std::string function_name("test_function");
  // make_unique instead of a raw `new`, matching RuntimeCall in this file.
  auto function =
      std::make_unique<mobile::Function>(c10::QualifiedName(function_name));
  c10::ivalue::TupleElements debug_handles_m_tuple;
  parseInstructions(
      function_name,
      std::move(*c10::ivalue::Tuple::create(instructions)).elements(),
      debug_handles_m_tuple,
      function.get());
  parseTypes(c10::ivalue::Tuple::create(types)->elements(), function.get());
  const size_t rsize = 5;
  parseRegisterSize(rsize, function.get());

  // 3. Prepare for inputs and run the function
  // Note that the first input is reserved for Module object.
  // Since this is a function test and Module object is not required,
  // a dummy IValue (0) is added here.
  // After run(), the result is read back from slot 0 of the input vector.
  std::vector<IValue> inputs{0, 1, 2, true};
  function->run(inputs);
  auto output = inputs[0].toList();
  ASSERT_EQ(output[0], 1);
  ASSERT_EQ(output[1], 2);

  // xfirst == false takes the else-branch and swaps the order.
  std::vector<IValue> inputs1{0, 1, 2, false};
  function->run(inputs1);
  auto output1 = inputs1[0].toList();
  ASSERT_EQ(output1[0], 2);
  ASSERT_EQ(output1[1], 1);
}
1164
TEST(RunTimeTest, ParseOperator) {
  // A simple example to show a simple bytecode that can be used independent of
  // PyTorch TorchScript serialization (unpickler, etc) and operator library.
  // It has one operator and we should be able to register it. The original
  // PyTorch program:

  // class Add(torch.nn.Module):
  //     def __init__(self) -> None:
  //         super().__init__()

  //     def forward(self, a, b):
  //         return a + b

  // 1. Prepare for the bytecode. In reality it can be from a customized
  // deserializer.
  std::vector<IValue> instructions{
      to_tuple({"STOREN", 1, 3}),
      to_tuple({"DROPR", 1, 0}),
      to_tuple({"MOVE", 2, 0}),
      to_tuple({"MOVE", 3, 0}),
      to_tuple({"OP", 0, 0}),
      to_tuple({"RET", 0, 0}),
  };
  std::vector<IValue> operators{
      to_tuple({"aten::add", "Tensor", 2}),
  };
  // No constant table is built: this test never calls parseConstants and the
  // instruction list above contains no LOADC.

  // 2. Parse the function
  std::string function_name("test_function");
  // make_unique instead of a raw `new`, matching RuntimeCall in this file.
  auto function =
      std::make_unique<mobile::Function>(c10::QualifiedName(function_name));
  c10::ivalue::TupleElements debug_handles_m_tuple;
  parseInstructions(
      function_name,
      std::move(*c10::ivalue::Tuple::create(instructions)).elements(),
      debug_handles_m_tuple,
      function.get());
  parseOperators(
      std::move(*c10::ivalue::Tuple::create(operators)).elements(),
      1,
      function.get());
  const size_t rsize = 5;
  parseRegisterSize(rsize, function.get());

  // 3. Prepare for inputs and run the function
  // Note that the first input is reserved for Module object.
  // Since this is a function test and Module object is not required,
  // a dummy IValue (0) is added here.
  // After run(), the result is read back from slot 0 of the input vector.
  std::vector<IValue> inputs{0, at::tensor(1), at::tensor(2)};
  function->run(inputs);
  auto output = inputs[0];
  ASSERT_EQ(output, at::tensor(3));
}
1220
1221 namespace {
testLiteModuleCompareResultTensors(Module & m,const std::vector<torch::jit::IValue> & inputs,const std::string & method_name="forward")1222 void testLiteModuleCompareResultTensors(
1223 Module& m,
1224 const std::vector<torch::jit::IValue>& inputs,
1225 const std::string& method_name = "forward") {
1226 auto outputref = m.get_method(method_name)(inputs).toTensor();
1227
1228 std::stringstream ss;
1229 m._save_for_mobile(ss);
1230 mobile::Module bc = _load_for_mobile(ss);
1231 IValue res;
1232 for (int i = 0; i < 3; ++i) {
1233 res = bc.get_method(method_name)(inputs);
1234 }
1235 auto output = res.toTensor();
1236 AT_ASSERT(outputref.dim() == output.dim());
1237 AT_ASSERT(output.equal(outputref));
1238 }
1239
testDefaultArgsPinv(int num_args)1240 void testDefaultArgsPinv(int num_args) {
1241 Module m("m");
1242 if (num_args == 1) {
1243 m.define(R"(
1244 def forward(self, input):
1245 return torch.linalg_pinv(input)
1246 )");
1247 } else if (num_args == 2) {
1248 m.define(R"(
1249 def forward(self, input):
1250 return torch.linalg_pinv(input, 1e-5)
1251 )");
1252 } else if (num_args == 3) {
1253 m.define(R"(
1254 def forward(self, input):
1255 return torch.linalg_pinv(input, 1e-5, True)
1256 )");
1257 }
1258
1259 std::vector<torch::jit::IValue> inputs;
1260 const int N = 28;
1261 auto input = torch::range(1, N * N, 1);
1262 input[0] = 1; // a more stable matrix
1263 input = input.view({N, N});
1264 inputs.push_back(input);
1265 testLiteModuleCompareResultTensors(m, inputs);
1266 }
1267 } // namespace
1268
1269 #if !defined FB_XPLAT_BUILD
TEST(LiteInterpreterTest, DefaultArgsPinv) {
  // Exercise linalg_pinv with one, two and three explicitly specified
  // arguments; every argument that is left out takes its default value.
  for (const int num_args : {1, 2, 3}) {
    testDefaultArgsPinv(num_args);
  }

  // For reference, the bytecode for each case is listed below.

  //  bytecode with one specified argument:
  //  (6,
  //      ('__torch__.m.forward',
  //          (('instructions',
  //              (('STOREN', 1, 2),
  //               ('DROPR', 1, 0),
  //               ('MOVE', 2, 0),
  //               ('OP', 0, 0),
  //               ('RET', 0, 0))),
  //              ('operators', (('aten::linalg_pinv', '', 1),)),
  //              ('constants', (False, 1e-15)), # default constants are not used
  //              ('types', ()),
  //              ('register_size', 2)),
  //          (('arguments',
  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value', None)),
  //               (('name', 'input'), ('type', 'Tensor'), ('default_value', None)))),
  //              ('returns',
  //                  ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))

  //  bytecode with 2 specified argument:
  //  (6,
  //      ('__torch__.m.forward',
  //          (('instructions',
  //              (('STOREN', 1, 2),
  //               ('DROPR', 1, 0),
  //               ('MOVE', 2, 0),
  //               ('LOADC', 1, 0), # added LOADC for specified argument
  //               ('OP', 0, 0),
  //               ('RET', 0, 0))),
  //              ('operators', (('aten::linalg_pinv', '', 2),)),
  //              ('constants', (False, 1e-05)), # updated constant table
  //              ('types', ()),
  //              ('register_size', 2)),
  //          (('arguments',
  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value', None)),
  //               (('name', 'input'), ('type', 'Tensor'), ('default_value', None)))),
  //              ('returns',
  //                  ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))

  //  bytecode with 3 specified arguments:
  //  (6,
  //      ('__torch__.m.forward',
  //          (('instructions',
  //              (('STOREN', 1, 2),
  //               ('DROPR', 1, 0),
  //               ('MOVE', 2, 0),
  //               ('LOADC', 1, 0),
  //               ('LOADC', 0, 0),
  //               ('OP', 0, 0),
  //               ('RET', 0, 0))),
  //              ('operators', (('aten::linalg_pinv', '', 3),)),
  //              ('constants', (True, 1e-05)),
  //              ('types', ()),
  //              ('register_size', 2)),
  //          (('arguments',
  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value', None)),
  //               (('name', 'input'), ('type', 'Tensor'), ('default_value', None)))),
  //              ('returns',
  //                  ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
}
1347
TEST(LiteInterpreterTest, DefaultArgsTensorinvSpecifyDefault) {
  // The second argument is written out explicitly, but its value equals the
  // schema default. Such an argument counts as "not specified" because the
  // value can be fetched from the schema.
  Module m("m");
  m.define(R"(
    def forward(self, input):
      return torch.linalg_tensorinv(input, 2)
  )");

  // Only one operator is recorded, with a single specified argument.
  torch::jit::MobileCode code(m.get_method("forward").graph(), "forward");
  auto specified_arg_counts = code.op_to_num_specified_args();
  ASSERT_EQ(specified_arg_counts.size(), 1);
  ASSERT_EQ(specified_arg_counts["aten::linalg_tensorinv"], 1);

  // The JIT and mobile results must still agree.
  const int N = 4;
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::rand({N, N, N, N}));
  testLiteModuleCompareResultTensors(m, inputs);
}
1367
testDefaultArgsPinvWithOutArg(int num_args)1368 void testDefaultArgsPinvWithOutArg(int num_args) {
1369 Module m("m");
1370 if (num_args == 1) {
1371 m.define(R"(
1372 def forward(self, input):
1373 return torch.linalg_pinv(input, out=input)
1374 )");
1375 } else if (num_args == 2) {
1376 m.define(R"(
1377 def forward(self, input):
1378 return torch.linalg_pinv(input, 1e-5, out=input)
1379 )");
1380 } else if (num_args == 3) {
1381 m.define(R"(
1382 def forward(self, input):
1383 return torch.linalg_pinv(input, 1e-5, True, out=input)
1384 )");
1385 }
1386
1387 const int N = 28;
1388 auto input = torch::range(1, N * N, 1);
1389 input[0] = 10000; // a more stable matrix
1390 input = input.view({N, N});
1391 auto ref = m.run_method("forward", input);
1392 TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
1393 TORCH_CHECK(input.equal(ref.toTensor()));
1394 }
1395
TEST(LiteInterpreterTest, DefaultArgsPinvWithOutArg) {
  // Cover one, two and three explicitly specified arguments combined with an
  // out argument; arguments that are left out take their default value.
  for (const int num_args : {1, 2, 3}) {
    testDefaultArgsPinvWithOutArg(num_args);
  }
}
1403
// Checks that an out-variant operator (aten::add.out) works on the mobile
// module and that the recorded operator info reports three schema arguments.
TEST(LiteInterpreterTest, DefaultArgsWithOutArg) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
      torch.add(x, h, out=x)
  )");

  auto input_x = 2 * torch::ones({});
  auto input_h = torch::ones({});
  // This call is needed for its side effect: it writes x + h into input_x
  // (2 + 1 = 3), which the assertion after the mobile run relies on.
  auto ref = m.run_method("forward", input_x, input_h);

  std::stringstream ss;

  m._save_for_mobile(ss, {}, true);
  mobile::Module bc = _load_for_mobile(ss);
  // Second in-place add on the same tensor: 3 + 1 = 4.
  bc.run_method("forward", input_x, input_h);
  ASSERT_TRUE(input_x.equal(4 * torch::ones({})));

  // Separate gtest assertions pinpoint which condition failed.
  auto ops = _get_model_ops_and_info(ss);
  auto op = ops.find("aten::add.out");
  ASSERT_TRUE(op != ops.end());
  ASSERT_TRUE(op->second.num_schema_args.has_value());
  ASSERT_TRUE(op->second.num_schema_args.value() == 3);
}
1429
// Builds a three-level module hierarchy C -> B0(B) -> A0(A), triggers a
// failure in the innermost add by passing tensors of shapes {2, 4} and
// {13, 9}, and checks the message of the exception raised by the mobile
// module against `error_pattern`.
TEST(LiteInterpreterTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
  Module a("A");
  a.define(R"(
    def bar(self, x, y):
      return x + y
  )");
  Module b("B");
  b.register_module("A0", a);
  b.define(R"(
    def foo(self, x, y):
      return self.A0.bar(x, y) + 2
  )");
  Module c("C");
  c.register_module("B0", b);
  c.define(R"(
    def forward(self, x, y):
      return self.B0.foo(x, y) + 3
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({2, 4}));
  inputs.emplace_back(torch::rand({13, 9}));

  std::stringstream ss;
  // The third argument is `true` here, unlike the plain _save_for_mobile(ss)
  // calls elsewhere in this file.
  // NOTE(review): presumably this enables saving mobile debug info, which the
  // module hierarchy / traceback in the message needs -- confirm.
  c._save_for_mobile(ss, ExtraFilesMap(), true);
  auto lite_m = _load_for_mobile(ss);
  // NOTE(review): the expected text echoes the scripts defined above, so the
  // whitespace inside this literal and inside those scripts is presumably
  // significant -- do not reformat either without re-running the test.
  std::string error_pattern = R"(
  Module hierarchy:top(C)::<unknown>.B0(B)::foo.A0(A)::bar.aten::add
Traceback of TorchScript (most recent call last):
  File "<string>", line 3, in <unknown>

    def forward(self, x, y):
      return self.B0.foo(x, y) + 3
             ~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in foo

    def foo(self, x, y):
      return self.A0.bar(x, y) + 2
             ~~~~~~~~~~~ <--- HERE

  File "<string>", line 3, in bar

    def bar(self, x, y):
      return x + y
             ~~~~~ <--- HERE
  )";
  ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern);
}
1479 #endif // !defined(FB_XPLAT_BUILD)
1480
namespace {
// Registers TorchBindLiteInterpreterTestStruct as the custom class
// "_TorchScriptTesting"/"_LiteInterpreterTest" with a default constructor, a
// `get` method and pickling support.
static auto reg =
    torch::class_<TorchBindLiteInterpreterTestStruct>(
        "_TorchScriptTesting",
        "_LiteInterpreterTest")
        .def(torch::init<>())
        .def("get", &TorchBindLiteInterpreterTestStruct::get)
        .def_pickle(
            // __getstate__: every instance is serialized as the constant 0.
            [](const c10::intrusive_ptr<TorchBindLiteInterpreterTestStruct>&
                   self) -> int64_t { return 0; },
            // __setstate__: the state is ignored and a fresh instance is
            // created.
            [](int64_t state) {
              return c10::make_intrusive<TorchBindLiteInterpreterTestStruct>();
            });

} // namespace
1498
TEST(LiteInterpreterTest, OperatorCacheDifferentiatesDefaultArgs) {
  // Three methods call the same operator with different argument lists:
  //
  // 1. forward() returns a tensor with dtype=torch.int64 (4)
  // 2. forward2() returns a tensor with dtype=torch.float32 (6)
  // 3. forward3() returns a tensor with dtype=torch.float32, where the
  //    dtype is inferred from the input tensor's dtype
  //
  // With a correct operator cache the full-jit module and the lite module
  // return the same results. They differ if a cache entry for an operator
  // with a different number of arguments is not correctly ignored.
  Module m("m");
  m.define(R"(
    def forward(self):
      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4)
      return ret1.fill_(25)
  )");
  m.define(R"(
    def forward2(self):
      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6)
      return ret1.fill_(32.0)
  )");
  m.define(R"(
    def forward3(self):
      ret1 = torch.new_empty(torch.zeros(10), [10])
      return ret1.fill_(12.0)
  )");

  // None of the methods takes an argument besides self.
  const std::vector<torch::jit::IValue> no_inputs;
  for (const char* method_name : {"forward", "forward2", "forward3"}) {
    testLiteModuleCompareResultTensors(m, no_inputs, method_name);
  }
}
1534
TEST(RunTimeTest, RuntimeCall) {
  // Hand-written bytecode for one function calling another:
  //
  //     def call(x):
  //         return x + x
  //
  //     def forward(a):
  //         x = a + call(a)
  //         y = a + call(x)
  //         return y

  // Callee.
  std::vector<IValue> call_instructions{
      to_tuple({"STORE", 1, 0}),
      to_tuple({"LOAD", 1, 0}),
      to_tuple({"MOVE", 1, 0}),
      to_tuple({"LOADC", 0, 0}),
      to_tuple({"OP", 0, 0}),
      to_tuple({"RET", 0, 0}),
  };
  std::vector<IValue> call_operators{
      to_tuple({"aten::add", "Tensor", 3}),
  };
  std::vector<IValue> call_constants{
      1,
  };

  // Caller (named "foo" in this test).
  std::vector<IValue> foo_instructions{
      to_tuple({"STORE", 1, 0}),
      to_tuple({"LOAD", 1, 0}),
      to_tuple({"LOAD", 1, 0}),
      to_tuple({"MOVE", 1, 0}),
      to_tuple({"CALL", 0, 0}),
      to_tuple({"LOADC", 0, 0}),
      to_tuple({"OP", 0, 0}),
      to_tuple({"CALL", 0, 0}),
      to_tuple({"LOADC", 0, 0}),
      to_tuple({"OP", 0, 0}),
      to_tuple({"RET", 0, 0}),
  };
  std::vector<IValue> foo_operators{
      to_tuple({"aten::add", "Tensor", 3}),
  };
  std::vector<IValue> foo_constants{
      1,
  };

  c10::ivalue::TupleElements debug_handles_m_tuple;
  const size_t rsize = 5;

  // Creates a mobile::Function called `name` and feeds it the instruction,
  // operator and constant tables plus the register size, in that order.
  const auto build_function = [&debug_handles_m_tuple, rsize](
                                  const std::string& name,
                                  std::vector<IValue>& instructions,
                                  std::vector<IValue>& operators,
                                  std::vector<IValue>& constants) {
    auto fn = std::make_unique<mobile::Function>(c10::QualifiedName(name));
    parseInstructions(
        name,
        std::move(*c10::ivalue::Tuple::create(instructions)).elements(),
        debug_handles_m_tuple,
        fn.get());
    parseOperators(
        std::move(*c10::ivalue::Tuple::create(operators)).elements(),
        1,
        fn.get());
    parseConstants(
        std::move(*c10::ivalue::Tuple::create(constants)).elements(),
        fn.get());
    parseRegisterSize(rsize, fn.get());
    return fn;
  };

  auto foo =
      build_function("foo", foo_instructions, foo_operators, foo_constants);
  auto call =
      build_function("call", call_instructions, call_operators, call_constants);

  // Make `call` available to foo's CALL instructions.
  foo->append_function(*call);

  std::vector<IValue> inputs{at::tensor(1)};
  foo->run(inputs);
  auto output = inputs[0];
  ASSERT_EQ(output, at::tensor(7));
}
1617
// The loaded function's code must carry exactly one input-size entry per
// operator.
TEST(LiteInterpreterTest, OperatorSize1) {
  Module scripted("m");
  scripted.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::stringstream buffer;
  scripted._save_for_mobile(buffer);
  mobile::Module lite = _load_for_mobile(buffer);

  const auto& func = lite.get_method("forward").function();
  const auto input_size_entries =
      func.get_code().operator_input_sizes_.size();
  const auto operator_entries = func.get_code().operators_.size();
  ASSERT_EQ(input_size_entries, operator_entries);
}
1633
// For several ways of defining `test_func`, the loaded function's code must
// carry exactly one input-size entry per operator.
TEST(LiteInterpreterTest, OperatorTest2) { // NOLINT (use =delete in gtest)
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
      def test_func(self, x, b : int = 4):
        return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
      def add_with_default_arg(self, x, b : int = 4):
        return self.foo + x + b
      def test_func(self, x):
        return self.add_with_default_arg(x)  # invoke method w/ default arg
      )",
      // simple method call
      R"(
      def test_func(self, x):
        b = 4
        return self.foo + x + b
      )",
  };

  for (const auto& program : test_programs) {
    Module scripted("m");
    scripted.register_parameter("foo", torch::ones({}), false);
    scripted.define(program);

    std::stringstream buffer;
    scripted._save_for_mobile(buffer);
    mobile::Module lite = _load_for_mobile(buffer);

    const auto& func = lite.get_method("test_func").function();
    const auto input_size_entries =
        func.get_code().operator_input_sizes_.size();
    const auto operator_entries = func.get_code().operators_.size();
    ASSERT_EQ(input_size_entries, operator_entries);
  }
}
1669
1670 #if !defined FB_XPLAT_BUILD
// The following tests run in fbcode only
// Loads a stored model containing three aten::div.Tensor operators, checks
// that the loaded bytecode contains three CALL instructions, and verifies the
// first output for 6 / 3.
TEST(LiteInterpreterUpgraderTest, DivTensorV2) {
  // The model file lives next to this source file.
  const std::string this_file(__FILE__);
  const auto dir_length = this_file.find_last_of("/\\") + 1;
  std::string model_path = this_file.substr(0, dir_length);
  model_path.append("upgrader_models/test_versioned_div_tensor_v2.ptl");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('LOAD', 2, 0),
       ('LOAD', 3, 0),
       ('OP', 0, 0),
       ('LOAD', 2, 0),
       ('LOAD', 3, 0),
       ('OP', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 2, 0),
       ('TUPLE_CONSTRUCT', 3, 0),
       ('RET', 0, 0))),
     ('operators',
      (('aten::div', 'Tensor'),
       ('aten::div', 'Tensor'),
       ('aten::div', 'Tensor'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)

  */
  mobile::Module m_module = _load_for_mobile(model_path);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t call_count = 0;
  for (auto& instruction : instruction_list) {
    if (instruction.op == OpCode::CALL) {
      ++call_count;
    }
  }
  // 3 operators will use upgrader
  ASSERT_EQ(call_count, 3);

  std::vector<IValue> inputs = {
      IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))};
  const auto expected = 2.0 * torch::ones({1});
  auto result = m_module.forward(inputs);
  auto result_elements = result.toTuple()->elements();
  ASSERT_TRUE(result_elements[0].toTensor().equal(expected));
}
1718
// Loads a stored model containing one aten::div.out operator, checks that the
// loaded bytecode contains one CALL instruction, and verifies that the third
// input receives 6 / 3.
TEST(LiteInterpreterUpgraderTest, DivTensorOutV2) {
  // The model file lives next to this source file.
  const std::string this_file(__FILE__);
  const auto dir_length = this_file.find_last_of("/\\") + 1;
  std::string model_path = this_file.substr(0, dir_length);
  model_path.append("upgrader_models/test_versioned_div_tensor_out_v2.ptl");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 4),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('MOVE', 4, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div', 'out'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 4))),)
  */
  mobile::Module m_module = _load_for_mobile(model_path);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t call_count = 0;
  for (auto& instruction : instruction_list) {
    if (instruction.op == OpCode::CALL) {
      ++call_count;
    }
  }
  // One operator will use upgrader
  ASSERT_EQ(call_count, 1);

  std::vector<IValue> inputs{
      IValue(6 * torch::ones({1})),
      IValue(3 * torch::ones({1})),
      IValue(torch::empty({1}))};
  m_module.forward(inputs);

  // The out argument (third input) will be overwritten with the output
  const auto expected = 2.0 * torch::ones({1});
  const auto out_tensor = inputs[2].toTensor();
  ASSERT_TRUE(out_tensor.equal(expected));
}
1760
// Loads a stored model containing one aten::div_.Tensor operator, checks that
// the loaded bytecode contains one CALL instruction, and verifies that the
// first input holds 6 / 3 after the run.
TEST(LiteInterpreterUpgraderTest, DivTensorInplaceV2) {
  // The model file lives next to this source file.
  const std::string this_file(__FILE__);
  const auto dir_length = this_file.find_last_of("/\\") + 1;
  std::string model_path = this_file.substr(0, dir_length);
  model_path.append(
      "upgrader_models/test_versioned_div_tensor_inplace_v2.ptl");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div_', 'Tensor'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = _load_for_mobile(model_path);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t call_count = 0;
  for (auto& instruction : instruction_list) {
    if (instruction.op == OpCode::CALL) {
      ++call_count;
    }
  }
  // One operator will use upgrader
  ASSERT_EQ(call_count, 1);

  std::vector<IValue> inputs{
      IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))};
  m_module.forward(inputs);

  // The result is read back from the first input tensor.
  const auto expected = 2.0 * torch::ones({1});
  const auto first_input = inputs[0].toTensor();
  ASSERT_TRUE(first_input.equal(expected));
}
1799
// Loads a stored model containing one aten::div.Scalar operator, checks that
// the loaded bytecode contains one CALL instruction, and verifies the
// returned tensor for 6 / 3.0.
TEST(LiteInterpreterUpgraderTest, DivScalarFloatV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_float_v2.ptl");
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */

  mobile::Module m_module = _load_for_mobile(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (const auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // One operator will use upgrader
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // There is no out argument in this test: the value returned by forward()
  // is what gets checked.
  ASSERT_TRUE(actual_output.equal(expect_output));
}
1839
// Loads a stored model built from aten::reciprocal and aten::mul.Scalar,
// checks that the loaded bytecode contains no CALL instruction, and verifies
// the returned tensor for (1 / 6) * 3.0.
TEST(LiteInterpreterUpgraderTest, DivScalarReciprocalFloatV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl");
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('OP', 0, 0),
       ('MOVE', 3, 0),
       ('OP', 1, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = _load_for_mobile(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (const auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // No operator will use upgrader
  ASSERT_EQ(number_of_call_instruction, 0);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 0.5 * torch::ones({1});
  auto actual_output = output.toTensor();

  // There is no out argument in this test: the value returned by forward()
  // is what gets checked. Both tensors are printed only if the check fails,
  // rather than unconditionally on stdout.
  ASSERT_TRUE(actual_output.equal(expect_output))
      << "expect output: " << expect_output
      << "actual output: " << actual_output;
}
1880
// Loads test_versioned_div_scalar_reciprocal_int_v2.ptl, checks that the
// loaded `forward` contains no CALL instruction, and verifies the result of
// the reciprocal-then-multiply computation.
TEST(LiteInterpreterUpgraderTest, DivScalarReciprocalIntV2) {
  // The model is stored in upgrader_models/ next to this source file.
  const std::string source_path(__FILE__);
  const std::string model_path =
      source_path.substr(0, source_path.find_last_of("/\\") + 1) +
      "upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl";
  /*
  (('__torch__.MyModuleInt.forward',
    (('instructions',
      (('STOREN', 1, 3),
      ('DROPR', 1, 0),
      ('MOVE', 2, 0),
      ('OP', 0, 0),
      ('MOVE', 3, 0),
      ('OP', 1, 0),
      ('RET', 0, 0))),
    ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))),
    ('constants', ()),
    ('types', ()),
    ('register_size', 3))),)
  */
  mobile::Module model = _load_for_mobile(model_path);

  // Count the CALL instructions present in the loaded method.
  const auto instructions =
      model.get_method("forward").function().get_code().instructions_;
  uint64_t call_count = 0;
  for (const auto& ins : instructions) {
    if (ins.op == OpCode::CALL) {
      ++call_count;
    }
  }
  // No operator is expected to use an upgrader, so there is no CALL.
  ASSERT_EQ(call_count, 0);

  // NOTE(review): the scalar is passed as 3.0 although this is the "Int"
  // variant of the test -- confirm that this is intended.
  std::vector<IValue> args;
  args.emplace_back(6 * torch::ones({1}));
  args.emplace_back(3.0);
  auto result = model.forward(args);

  // reciprocal(6) * 3.0 == 0.5
  const auto expected = 0.5 * torch::ones({1});
  ASSERT_TRUE(result.toTensor().equal(expected));
}
1920
// Loads test_versioned_div_scalar_scalar_v2.ptl, checks that the loaded
// `forward` contains no CALL instruction, and verifies the tuple of
// scalar / scalar division results element by element.
TEST(LiteInterpreterUpgraderTest, DivScalarScalarV2) {
  // The model is stored in upgrader_models/ next to this source file.
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_scalar_v2.ptl");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 5),
      ('DROPR', 1, 0),
      ('LOAD', 2, 0),
      ('LOAD', 3, 0),
      ('OP', 0, 0),
      ('MOVE', 2, 0),
      ('LOAD', 4, 0),
      ('OP', 1, 0),
      ('LOAD', 3, 0),
      ('MOVE', 4, 0),
      ('OP', 2, 0),
      ('MOVE', 3, 0),
      ('MOVE', 5, 0),
      ('OP', 3, 0),
      ('TUPLE_CONSTRUCT', 4, 0),
      ('RET', 0, 0))),
    ('operators',
      (('aten::div', ''),
      ('aten::div', 'float'),
      ('aten::div', ''),
      ('aten::div', 'int'))),
    ('constants', ()),
    ('types', ()),
    ('register_size', 5))),)
  */
  mobile::Module m_module = _load_for_mobile(test_model_file);
  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // No operator will use upgrader
  ASSERT_EQ(number_of_call_instruction, 0);

  std::vector<IValue> inputs{IValue(20.0), IValue(10), IValue(2.0), IValue(5)};
  auto output = m_module.forward(inputs);
  const auto& output_list = output.toTupleRef().elements();
  // 20.0 / 10, 20.0 / 2.0, 10 / 2.0, 10 / 5
  auto expect_output = std::vector<IValue>(
      {IValue(2.0), IValue(10.0), IValue(5.0), IValue(2.0)});
  // Guard the indexed comparison below: a shorter result tuple would
  // otherwise be read out of bounds.
  ASSERT_EQ(output_list.size(), expect_output.size());
  for (size_t i = 0; i < expect_output.size(); i++) {
    ASSERT_EQ(output_list[i], expect_output[i]) << "mismatch at index " << i;
  }
}
1974
// Loads test_versioned_div_scalar_int_v2.ptl, checks how many CALL
// instructions the loaded `forward` contains, and verifies tensor / int
// scalar division through the lite interpreter.
TEST(LiteInterpreterUpgraderTest, DivScalarIntV2) {
  // The model is stored in upgrader_models/ next to this source file.
  const std::string source_path(__FILE__);
  const std::string model_path =
      source_path.substr(0, source_path.find_last_of("/\\") + 1) +
      "upgrader_models/test_versioned_div_scalar_int_v2.ptl";
  /*
  (('__torch__.MyModuleInt.forward',
    (('instructions',
      (('STOREN', 1, 3),
      ('DROPR', 1, 0),
      ('MOVE', 2, 0),
      ('MOVE', 3, 0),
      ('OP', 0, 0),
      ('RET', 0, 0))),
    ('operators', (('aten::div', 'Scalar'),)),
    ('constants', ()),
    ('types', ()),
    ('register_size', 3))),)
  */
  mobile::Module model = _load_for_mobile(model_path);

  // Count the CALL instructions present in the loaded method.
  const auto instructions =
      model.get_method("forward").function().get_code().instructions_;
  uint64_t call_count = 0;
  for (const auto& ins : instructions) {
    if (ins.op == OpCode::CALL) {
      ++call_count;
    }
  }
  // One operator is expected to use an upgrader, i.e. a single CALL.
  ASSERT_EQ(call_count, 1);

  std::vector<IValue> args;
  args.emplace_back(6 * torch::ones({1}));
  args.emplace_back(3);
  auto result = model.forward(args);

  // 6 / 3 == 2.0
  const auto expected = 2.0 * torch::ones({1});
  ASSERT_TRUE(result.toTensor().equal(expected));
}
2013
// Loads test_versioned_div_scalar_inplace_float_v2.ptl, checks how many CALL
// instructions the loaded `forward` contains, and verifies the value returned
// by the in-place tensor / float scalar division.
TEST(LiteInterpreterUpgraderTest, DivScalarInplaceFloatV2) {
  // The model is stored in upgrader_models/ next to this source file.
  const std::string source_path(__FILE__);
  const std::string model_path =
      source_path.substr(0, source_path.find_last_of("/\\") + 1) +
      "upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl";
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
      ('DROPR', 1, 0),
      ('MOVE', 2, 0),
      ('MOVE', 3, 0),
      ('OP', 0, 0),
      ('RET', 0, 0))),
    ('operators', (('aten::div_', 'Scalar'),)),
    ('constants', ()),
    ('types', ()),
    ('register_size', 3))),)
  */

  mobile::Module model = _load_for_mobile(model_path);

  // Count the CALL instructions present in the loaded method.
  const auto instructions =
      model.get_method("forward").function().get_code().instructions_;
  uint64_t call_count = 0;
  for (const auto& ins : instructions) {
    if (ins.op == OpCode::CALL) {
      ++call_count;
    }
  }
  // One operator is expected to use an upgrader, i.e. a single CALL.
  ASSERT_EQ(call_count, 1);

  std::vector<IValue> args;
  args.emplace_back(6 * torch::ones({1}));
  args.emplace_back(3.0);
  auto result = model.forward(args);

  // The value returned by forward() should equal 6 / 3.0 == 2.0.
  const auto expected = 2.0 * torch::ones({1});
  ASSERT_TRUE(result.toTensor().equal(expected));
}
2053
// Loads test_versioned_div_scalar_inplace_int_v2.ptl, checks how many CALL
// instructions the loaded `forward` contains, and verifies the value returned
// by the in-place tensor / int scalar division.
TEST(LiteInterpreterUpgraderTest, DivScalarInplaceIntV2) {
  // The model is stored in upgrader_models/ next to this source file.
  const std::string source_path(__FILE__);
  const std::string model_path =
      source_path.substr(0, source_path.find_last_of("/\\") + 1) +
      "upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl";
  /*
  (('__torch__.MyModuleInt.forward',
    (('instructions',
      (('STOREN', 1, 3),
      ('DROPR', 1, 0),
      ('MOVE', 2, 0),
      ('MOVE', 3, 0),
      ('OP', 0, 0),
      ('RET', 0, 0))),
    ('operators', (('aten::div_', 'Scalar'),)),
    ('constants', ()),
    ('types', ()),
    ('register_size', 3))),)
  */

  mobile::Module model = _load_for_mobile(model_path);

  // Count the CALL instructions present in the loaded method.
  const auto instructions =
      model.get_method("forward").function().get_code().instructions_;
  uint64_t call_count = 0;
  for (const auto& ins : instructions) {
    if (ins.op == OpCode::CALL) {
      ++call_count;
    }
  }
  // One operator is expected to use an upgrader, i.e. a single CALL.
  ASSERT_EQ(call_count, 1);

  std::vector<IValue> args;
  args.emplace_back(6 * torch::ones({1}));
  args.emplace_back(3);
  auto result = model.forward(args);

  // The value returned by forward() should equal 6 / 3 == 2.0.
  const auto expected = 2.0 * torch::ones({1});
  ASSERT_TRUE(result.toTensor().equal(expected));
}
2093
2094 #endif // !defined(FB_XPLAT_BUILD)
2095
// Walks every entry returned by getUpgraderBytecodeList(), initializes its
// operators, checks that the operator table and the operator-name table have
// the same length, and collects a copy of each function.
TEST(LiteInterpreterUpgraderTest, Upgrader) {
  std::vector<mobile::Function> collected;

  for (auto& entry : getUpgraderBytecodeList()) {
    auto& fn = entry.function;
    fn.initialize_operators(true);

    ASSERT_EQ(
        fn.get_code().operators_.size(), fn.get_code().op_names_.size());

    // If no operators are registered on the function yet, append the ones
    // listed alongside it in the entry.
    if (fn.get_code().operators_.empty()) {
      for (const auto& op : entry.operators) {
        fn.append_operator(op.name, op.overload_name, op.num_specified_args);
      }
    }
    collected.push_back(fn);
  }

  // Every entry of the list must have produced exactly one function.
  ASSERT_EQ(getUpgraderBytecodeList().size(), collected.size());
}
2115
enumerateTupleType(size_t depth,std::vector<TypePtr> & current,const std::vector<TypePtr> & candidates,std::vector<TypePtr> & out)2116 void enumerateTupleType(
2117 size_t depth,
2118 std::vector<TypePtr>& current,
2119 const std::vector<TypePtr>& candidates,
2120 std::vector<TypePtr>& out) {
2121 static std::vector<std::string> fieldNames;
2122 if (depth > fieldNames.size()) {
2123 fieldNames.reserve(depth);
2124 for (size_t i = fieldNames.size(); i < depth; i++) {
2125 fieldNames.push_back("field" + std::to_string(i));
2126 }
2127 }
2128 if (depth == 0) {
2129 out.push_back(TupleType::create(current));
2130 while (fieldNames.size() > current.size()) {
2131 fieldNames.pop_back();
2132 }
2133 out.push_back(TupleType::createNamed("NamedTuple", fieldNames, current));
2134 return;
2135 }
2136 for (const auto& type : candidates) {
2137 if (containsAnyType(type)) {
2138 continue;
2139 }
2140 current.push_back(type);
2141 enumerateTupleType(depth - 1, current, candidates, out);
2142 current.pop_back();
2143 }
2144 }
2145
2146 class LiteInterpreterDynamicTypeTestFixture
2147 : public ::testing::TestWithParam<size_t> {
2148 protected:
SetUp()2149 void SetUp() override {
2150 cu = std::make_shared<CompilationUnit>();
2151 std::vector<TypePtr> keyTypes = {
2152 AnyType::get(),
2153 IntType::get(),
2154 BoolType::get(),
2155 FloatType::get(),
2156 ComplexType::get(),
2157 StringType::get(),
2158 TensorType::get(),
2159 DeviceObjType::get(),
2160 };
2161 types = {
2162 NoneType::get(),
2163 NumberType::get(),
2164 ClassType::create("__torch__.TestClass1", cu),
2165 ClassType::create("__torch__.TestClass2", cu),
2166 AnyListType::get(),
2167 AnyTupleType::get(),
2168 StreamObjType::get(),
2169 CapsuleType::get(),
2170 GeneratorType::get(),
2171 StorageType::get(),
2172 VarType::create("t"),
2173 VarType::create("v"),
2174 AnyClassType::get()};
2175 std::copy(keyTypes.begin(), keyTypes.end(), back_inserter(types));
2176 auto expandTypes = [&](size_t tupleSize) {
2177 std::vector<TypePtr> nested;
2178 for (const auto& type : types) {
2179 if (!(type == AnyType::get())) {
2180 nested.emplace_back(ListType::create(type));
2181 if (!(type == NoneType::get() ||
2182 type->kind() == OptionalType::Kind)) {
2183 nested.emplace_back(OptionalType::create(type));
2184 }
2185 }
2186 for (const auto& keyType : keyTypes) {
2187 nested.emplace_back(DictType::create(keyType, type));
2188 }
2189 }
2190 std::vector<TypePtr> tmp;
2191 enumerateTupleType(tupleSize, tmp, types, nested);
2192 std::move(
2193 std::begin(nested), std::end(nested), std::back_inserter(types));
2194 };
2195 expandTypes(1);
2196 expandTypes(1);
2197 }
2198 std::shared_ptr<CompilationUnit> cu;
2199 std::vector<TypePtr> types;
2200
2201 public:
2202 static constexpr size_t kNumSplits = 10;
2203 };
2204
2205 /**
2206 * Enumerate all possible JIT types appearing in mobile runtime, and test
2207 * whether subtyping relation is preserved after one of the JIT types is
2208 * converted to DynamicType.
2209 *
2210 * We firstly enumerate all "base" types in a vector, and implement
2211 * expandTypes() to enumerate container types one "level" up for a given set
2212 * of types. We call expandTypes() twice to test types nested less or equal
2213 * to two levels. e.g. List[Optional[Tensor]], Optional[Dict[Int, Bool]], etc.
2214 */
TEST_P(LiteInterpreterDynamicTypeTestFixture,Conformance)2215 TEST_P(LiteInterpreterDynamicTypeTestFixture, Conformance) {
2216 size_t num = types.size() / LiteInterpreterDynamicTypeTestFixture::kNumSplits;
2217 size_t begin = num * GetParam();
2218 size_t end = std::min(types.size(), begin + num);
2219 for (const auto& a : types) {
2220 auto da = DynamicType::create(*a);
2221 for (size_t i = begin; i < end; i++) {
2222 const auto& b = types[i];
2223 bool result = a->isSubtypeOf(*b);
2224 EXPECT_EQ(result, da->isSubtypeOf(*b));
2225 result = b->isSubtypeOf(*a);
2226 EXPECT_EQ(result, b->isSubtypeOf(*da));
2227 }
2228 }
2229 }
2230
2231 INSTANTIATE_TEST_SUITE_P(
2232 PyTorch,
2233 LiteInterpreterDynamicTypeTestFixture,
2234 ::testing::Range(
2235 static_cast<size_t>(0),
2236 LiteInterpreterDynamicTypeTestFixture::kNumSplits));
2237
2238 } // namespace jit
2239 } // namespace torch
2240