// test/cpp/jit/test_flatbuffer.cpp
#include <test/cpp/jit/test_utils.h>

#include <gtest/gtest.h>

#include <c10/core/TensorOptions.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/mobile/compatibility/backport.h>
#include <torch/csrc/jit/mobile/compatibility/backport_manager.h>
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/export_bytecode.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
#include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/custom_class.h>
#include <torch/torch.h>

#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/serialization/import_export_functions.h>
#include <unordered_set>

#if defined(FB_XPLAT_BUILD) || defined(FBCODE_CAFFE2)
#include <torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h> // NOLINT
namespace flatbuffers = flatbuffers_fbsource;
#define FLATBUFFERS_MAX_ALIGNMENT FLATBUFFERS_FBSOURCE_MAX_ALIGNMENT
#else
#include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
#endif
// Tests go in torch::jit
namespace torch {
namespace jit {

namespace {
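// Test-local helper: parses a serialized mobile module from raw bytes. When
// should_copy_tensor_memory is true, the loader copies tensor data out of
// `data` instead of aliasing it, so the buffer may be freed after parsing.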
mobile::Module parse_mobile_module(
    void* data,
    size_t size,
    bool should_copy_tensor_memory = false) {
  return parse_and_initialize_mobile_module(
      static_cast<char*>(data),
      size,
      /*device=*/std::nullopt,
      /*extra_files=*/nullptr,
      should_copy_tensor_memory);
}
} // namespace

TEST(FlatbufferTest, LoadMalformedModule) {
  // Manually create some data with Flatbuffer header.
  std::stringstream bad_data;
  bad_data << "PK\x03\x04PTMF\x00\x00"
           << "*}NV\xb3\xfa\xdf\x00pa";
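  // "PTMF" at offset 4 sits where a flatbuffer file identifier lives, so the
  // stream is detected as a flatbuffer module; the rest of the payload is
  // garbage, which the parser must reject.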

  // Loading module from it should throw an exception.
  // Check guard at parse_and_initialize_mobile_module_for_jit.
  ASSERT_THROWS_WITH_MESSAGE(
      torch::jit::load(bad_data), "Malformed Flatbuffer module");

  // Check guard at parse_and_initialize_mobile_module.
  ASSERT_THROWS_WITH_MESSAGE(
      parse_mobile_module(bad_data.str().data(), bad_data.str().size()),
      "Malformed Flatbuffer module");
}

TEST(FlatbufferTest, UpsampleNearest2d) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  auto ref = m.forward(inputs);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  res = bc.forward(inputs);

  auto resd = res.toTensor();
  auto refd = ref.toTensor();
  ASSERT_TRUE(resd.equal(refd));

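  // Round-trip through flatbuffer bytes and check the reloaded module
  // produces the same output.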
  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  auto res2 = bc2.forward(inputs);
  auto resd2 = res2.toTensor();
  ASSERT_TRUE(resd2.equal(refd));
}

TEST(FlatbufferTest, UpsampleNearest2dWithCopyTensorMemory) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  auto ref = m.forward(inputs);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  res = bc.forward(inputs);

  auto resd = res.toTensor();
  auto refd = ref.toTensor();
  ASSERT_TRUE(resd.equal(refd));

  auto buff = save_mobile_module_to_bytes(bc);
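  // Parse with should_copy_tensor_memory=true so the module owns copies of
  // the tensor data rather than aliasing `buff`.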
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size(), true);

  auto res2 = bc2.forward(inputs);
  auto resd2 = res2.toTensor();
  ASSERT_TRUE(resd2.equal(refd));
}

TEST(FlatbufferTest, CheckAttrAccess) {
  Module m("m");
  m.register_attribute("mobile_optimized", BoolType::get(), true);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bool mobile_optimized = bc.attr("mobile_optimized", false).toBool();

  AT_ASSERT(mobile_optimized);
  m.setattr("mobile_optimized", false);
  bc = jitModuleToMobile(m, options);
  mobile_optimized = bc.attr("mobile_optimized", false).toBool();

  AT_ASSERT(!mobile_optimized);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  auto mobile_optimized2 = bc2.attr("mobile_optimized", false).toBool();
  AT_ASSERT(!mobile_optimized2);
}

TEST(FlatbufferTest, MethodInvocation) { // NOLINT (use =delete in gtest)
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
      def test_func(self, x, b : int = 4):
        return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
      def add_with_default_arg(self, x, b : int = 4):
        return self.foo + x + b
      def test_func(self, x):
        return self.add_with_default_arg(x)  # invoke method w/ default arg
      )",
      // simple method call
      R"(
      def test_func(self, x):
        b = 4
        return self.foo + x + b
      )",
  };
  for (const auto& test_program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(test_program);

    const int fortyTwo = 42; // (keep linter happy)
    auto minput = fortyTwo * torch::ones({});
    auto ref = m.run_method("test_func", minput);

    CompilationOptions options;
    mobile::Module bc = jitModuleToMobile(m, options);
    const auto& test_func = bc.get_method("test_func");
    IValue res;
    for (int i = 0; i < 3; ++i) {
      res = test_func({minput});
    }

    auto resd = res.toTensor().item<float>();
    auto refd = ref.toTensor().item<float>();
    AT_ASSERT(resd == refd);

    auto buff = save_mobile_module_to_bytes(bc);
    mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
    const auto& test_func2 = bc2.get_method("test_func");
    IValue res2;
    for (int i = 0; i < 3; ++i) {
      res2 = test_func2({minput});
    }
    auto resd2 = res2.toTensor().item<float>();
    AT_ASSERT(resd2 == refd);
  }
}

#if !defined(FB_XPLAT_BUILD)
TEST(FlatbufferTest, FlatbufferBackPortTest) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");
  std::stringstream ss;
  m._save_for_mobile(ss, {}, false, true);

  std::stringstream oss;
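  // Backport the flatbuffer-serialized module to bytecode version 5 and
  // expect the conversion to succeed.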
  bool backPortSuccess = _backport_for_mobile(ss, oss, 5);
  ASSERT_TRUE(backPortSuccess);
}
#endif // !defined(FB_XPLAT_BUILD)

TEST(FlatbufferTest, ExtraFiles) {
  const auto script = R"JIT(
    def forward(self):
        x = torch.rand(5, 5)
        x = x.mm(x)
        return x
  )JIT";

  auto module =
      std::make_shared<Module>("Module", std::make_shared<CompilationUnit>());
  module->define(script);
  std::ostringstream oss;
  std::unordered_map<std::string, std::string> extra_files;
  extra_files["metadata.json"] = "abc";
  extra_files["mobile_info.json"] = "{\"key\": 23}";

  std::unordered_map<std::string, std::string> loaded_extra_files;
  std::stringstream ss;
  module->_save_for_mobile(ss, extra_files, true, /*use_flatbuffer=*/true);

  loaded_extra_files["metadata.json"] = "";
  auto mobile_module = _load_for_mobile(ss, std::nullopt, loaded_extra_files);

  ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
  ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");

  // load it twice using the same stream
  auto mobile_module2 = _load_for_mobile(ss, std::nullopt, loaded_extra_files);

  ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
  ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");

  // Check that flatbuffer loading does not require explicit key entries in
  // the extra-file map: with PARSE_ALL_EXTRA_FILE_MAPS, all entries are
  // returned.
  std::unordered_map<std::string, std::string>
      loaded_extra_files_without_explicit_entries;
  auto mobile_module3 = _load_for_mobile(
      ss,
      std::nullopt,
      loaded_extra_files_without_explicit_entries,
      MobileModuleLoadOptions::PARSE_ALL_EXTRA_FILE_MAPS);

  ASSERT_EQ(
      loaded_extra_files_without_explicit_entries["metadata.json"], "abc");
  ASSERT_EQ(
      loaded_extra_files_without_explicit_entries["mobile_info.json"],
      "{\"key\": 23}");
}

TEST(FlatbufferTest, Conv) {
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && strcmp(s, "1") == 0)
    return;

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  for (int i = 0; i < 3; ++i) {
    res = bc2.get_method("forward")(inputs);
  }
  output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}

TEST(FlatbufferTest, ConvWithCopyTensorMemory) {
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && strcmp(s, "1") == 0)
    return;

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size(), true);

  for (int i = 0; i < 3; ++i) {
    res = bc2.get_method("forward")(inputs);
  }
  output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}

TEST(FlatbufferTest, Inline) {
  Module m("m");
  m.define(R"JIT(
  def foo1(self, x):
      return x + 1

  def foo2(self, x):
      return self.foo1(x) + 2

  def foo3(self, x):
      return self.foo2(x) + 3
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("foo3")(inputs);
  AT_ASSERT(output.toTensor().item<float>() == 7.0);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  std::vector<torch::jit::IValue> inputs2({torch::ones({})});
  output = bc2.get_method("foo3")(inputs2);
  AT_ASSERT(output.toTensor().item<float>() == 7.0);
}

TEST(FlatbufferTest, InlineWithCopyTensorMemory) {
  Module m("m");
  m.define(R"JIT(
  def foo1(self, x):
      return x + 1

  def foo2(self, x):
      return self.foo1(x) + 2

  def foo3(self, x):
      return self.foo2(x) + 3
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("foo3")(inputs);
  AT_ASSERT(output.toTensor().item<float>() == 7.0);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size(), true);
  std::vector<torch::jit::IValue> inputs2({torch::ones({})});
  output = bc2.get_method("foo3")(inputs2);
  AT_ASSERT(output.toTensor().item<float>() == 7.0);
}

TEST(FlatbufferTest, Tuple) {
  Module m("m");
  m.define(R"JIT(
  def foo(self, x):
      return (1, 2, x + 3)

  def forward(self, x):
      tuple = self.foo(x)
      return tuple
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("forward")(inputs);
  AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  output = bc2.get_method("forward")(inputs);
  AT_ASSERT(output.toTuple()->elements()[1].toInt() == 2);
}

TEST(FlatbufferTest, Dict) {
  Module m("m");
  m.define(R"JIT(
  def foo(self, x):
      return {"result": x + 1}

  def forward(self, x):
      d = self.foo(x)
      return d
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("forward")(inputs);
  AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  output = bc2.get_method("forward")(inputs);
  AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
}

TEST(FlatbufferTest, Prim) {
  Module m("m");
  m.define(R"JIT(
        def forward(self, x):
            return int(x)
  )JIT");

  std::vector<IValue> inputs;
  auto minput = 3.5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resi = res.toInt();
  auto refi = ref.toInt();
  AT_ASSERT(resi == refi);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc2.get_method("forward")(bcinputs);
  }
  auto resi2 = res.toInt();
  AT_ASSERT(resi2 == refi);
}

TEST(FlatbufferTest, PrimScalar) {
  Module m("m");
  m.define(R"JIT(
        def forward(self, x):
            return int(x.item())
  )JIT");

  std::vector<IValue> inputs;
  auto minput = 3.5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resi = res.toInt();
  auto refi = ref.toInt();
  AT_ASSERT(resi == refi);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc2.get_method("forward")(bcinputs);
  }
  auto resi2 = res.toInt();
  AT_ASSERT(resi2 == refi);
}

TEST(FlatbufferTest, WrongMethodName) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  ASSERT_THROWS_WITH_MESSAGE(
      bc.get_method("forward")(inputs), "is not defined");

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  ASSERT_THROWS_WITH_MESSAGE(
      bc2.get_method("forward")(inputs), "is not defined");
}

TEST(FlatbufferTest, SetState) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);

  std::stringstream ms;
  m.save(ms);
  auto loaded_m = load(ms);
  auto ref = loaded_m.run_method("forward", minput);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc.get_method("forward")(bcinputs);
  }

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  for (int i = 0; i < 3; ++i) {
    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
    auto bcinputs = inputs;
    res = bc2.get_method("forward")(bcinputs);
  }

  auto resd2 = res.toTensor().item<float>();
  AT_ASSERT(resd2 == refd);
}

class TorchBindFlatbufferTestStruct : public torch::jit::CustomClassHolder {
 public:
  std::string get(at::Tensor t) {
    std::stringstream ss;
    ss << "Hello! Your tensor has ";
    ss << t.numel();
    ss << " elements!";
    return ss.str();
  }
};

namespace {
struct ClassNamespaceValue : public SugaredValue {
  explicit ClassNamespaceValue(c10::QualifiedName name)
      : basename_(std::move(name)) {}

  std::shared_ptr<SugaredValue> attr(
      const SourceRange& loc,
      GraphFunction& m,
      const std::string& name) override {
    const auto fullName = c10::QualifiedName(basename_, name);

    // Check to see if it is a custom class.
    if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
      return std::make_shared<ClassValue>(custom_class);
    }

    // If it's not a custom class, assume it's another namespace
    // NOLINTNEXTLINE(performance-move-const-arg)
    return std::make_shared<ClassNamespaceValue>(std::move(fullName));
  }

  std::string kind() const override {
    return "Class Namespace";
  }

 private:
  c10::QualifiedName basename_;
};

struct TestModuleResolver : public Resolver {
  std::shared_ptr<SugaredValue> resolveValue(
      const std::string& name,
      GraphFunction& m,
      const SourceRange& loc) override {
    if (name == "torch") {
      return std::make_shared<BuiltinModule>("aten");
    } else if (name == "__torch__") {
      return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
    }

    return nullptr;
  }

  TypePtr resolveType(const std::string& name, const SourceRange& loc)
      override {
    return nullptr;
  }
};
} // namespace

TEST(FlatbufferTest, BuiltinClass) {
  script::Module m("m");

  auto cls = getCustomClass(
      "__torch__.torch.classes._TorchScriptTesting._FlatbufferTest");
  TORCH_INTERNAL_ASSERT(cls);
  c10::intrusive_ptr<torch::CustomClassHolder> obj_holder;
  m.register_attribute("my_obj", cls, IValue::make_capsule(obj_holder));

  m.register_parameter("foo", torch::ones({}), false);
  m.define(
      R"(
    def __getstate__(self):
      return 1
    def __setstate__(self, a):
      self.my_obj = __torch__.torch.classes._TorchScriptTesting._FlatbufferTest()

    def forward(self, x) -> str:
      return self.my_obj.get(x)
  )",
      std::make_shared<TestModuleResolver>());

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  std::string expected = "Hello! Your tensor has 12 elements!";
  auto res =
      bc2.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
  const auto& str2 = res.toStringRef();
  AT_ASSERT(str2 == expected);
}

TEST(FlatbufferTest, BuiltinFunction) {
  script::Module m("m");
  auto custom_class_obj = make_custom_class<TorchBindFlatbufferTestStruct>();
  m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj);
  m.define(R"(
    def forward(self, x) -> str:
      return self.my_obj.get(x)
  )");

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  auto res =
      bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  auto str = res.toStringRef();
  std::string expected = "Hello! Your tensor has 12 elements!";
  AT_ASSERT(str == expected);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  res = bc2.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  str = res.toStringRef();
  AT_ASSERT(str == expected);
}

TEST(FlatbufferTest, Eval) {
  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.define(R"(
    def __init__(self, x):
      self.training = True

    def forward(self, input):
      return torch.dropout(input, 1.0, self.training)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));
  m.eval();
  auto outputref = m.forward(inputs).toTensor();

  // save m in training mode to make sure that mobile eval() will correctly
  // change back to eval mode
  m.train();
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bc.eval();
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  bc2.eval();
  for (int i = 0; i < 3; ++i) {
    res = bc2.get_method("forward")(inputs);
  }
  output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}

TEST(FlatbufferTest, FindWrongMethodName) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  ASSERT_TRUE(bc.find_method("forward") == std::nullopt);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  ASSERT_TRUE(bc2.find_method("forward") == std::nullopt);
}

TEST(FlatbufferTest, FindAndRunMethod) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_it(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.get_method("add_it")(inputs);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    auto bcinputs = inputs;
    auto method = bc.find_method("add_it");
    AT_ASSERT(method != std::nullopt);
    res = (*method)(std::move(bcinputs));
  }

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());

  for (int i = 0; i < 3; ++i) {
    auto bcinputs = inputs;
    auto method = bc2.find_method("add_it");
    AT_ASSERT(method != std::nullopt);
    res = (*method)(std::move(bcinputs));
  }

  resd = res.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

TEST(FlatbufferTest, RunMethodVariadic) {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_three(self, x, y):
      return self.foo + x + y
  )");

  std::vector<IValue> inputs;
  auto inputx = 5 * torch::ones({});
  auto inputy = 4 * torch::ones({});
  auto ref = m.run_method("add_three", inputx, inputy);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res = bc.run_method("add_three", inputx, inputy);

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  res = bc2.run_method("add_three", inputx, inputy);
  resd = res.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

TEST(FlatbufferTest, DuplicateSetState) {
  Module m("M");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo + self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  Module b("B");
  b.register_module("M0", m);
  b.register_module("M1", m);
  b.define(R"(
    def forward(self, x):
      return self.M0.forward(x) + self.M1.forward(x)
  )");

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  const auto methods = bc.get_methods();
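  // Three methods are expected: forward, __getstate__, and __setstate__.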
  const size_t expected_n = 3;
  ASSERT_EQ(methods.size(), expected_n);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  const auto methods2 = bc2.get_methods();
  ASSERT_EQ(methods2.size(), expected_n);
}

TEST(FlatbufferTest, OpNameExportFetchRootOperators) {
  torch::jit::Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      x1 = torch.zeros(2, 2)
      x2 = torch.empty_like(torch.empty(2, 2))
      x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
      return (x1, x2, x3)
  )");
  m.eval();

  CompilationOptions options;
  mobile::Module ptl_model = jitModuleToMobile(m, options);
  std::set<std::string> operator_names =
      torch::jit::mobile::_export_operator_list(ptl_model);
  std::set<std::string> expected_operator_names = {
      "aten::_convolution",
      "aten::empty.memory_format",
      "aten::empty_like",
      "aten::zeros",
  };
  EXPECT_EQ(operator_names, expected_operator_names)
      << "Expected the root operator lists to be the same";

  auto buff = save_mobile_module_to_bytes(ptl_model);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  operator_names = torch::jit::mobile::_export_operator_list(bc2);
  EXPECT_EQ(operator_names, expected_operator_names)
      << "Expected the root operator lists to be the same";
}

TEST(FlatbufferTest, DefaultArgsConv) {
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && strcmp(s, "1") == 0)
    return;

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1)
  )");

  inputs.emplace_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 1; ++i) {
    res = bc.get_method("forward")(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(output.equal(outputref));

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  for (int i = 0; i < 1; ++i) {
    res = bc2.get_method("forward")(inputs);
  }
  output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(output.equal(outputref));
}

namespace {
void testLiteModuleCompareResultTensors(
    Module& m,
    const std::vector<torch::jit::IValue>& inputs,
    const std::string& method_name = "forward") {
  auto outputref = m.get_method(method_name)(inputs).toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.get_method(method_name)(inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(output.equal(outputref));

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  for (int i = 0; i < 3; ++i) {
    res = bc2.get_method(method_name)(inputs);
  }
  output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(output.equal(outputref));
}

static void testDefaultArgsPinv(int num_args) {
  Module m("m");
  if (num_args == 1) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input)
    )");
  } else if (num_args == 2) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5)
    )");
  } else if (num_args == 3) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, True)
    )");
  }

  std::vector<torch::jit::IValue> inputs;
  const int N = 28;
  auto input = torch::range(1, N * N, 1);
  input[0] = 1; // a more stable matrix
  input = input.view({N, N});
  inputs.emplace_back(input);
  testLiteModuleCompareResultTensors(m, inputs);
}
} // namespace

#if !defined(FB_XPLAT_BUILD)
TEST(FlatbufferTest, DefaultArgsPinv) {
  // Test with different number of specified arguments.
  // Arguments not specified take default value.
  for (int num_args = 1; num_args <= 3; ++num_args) {
    testDefaultArgsPinv(num_args);
  }

  //  bytecode with one specified argument:
  //  (6,
  //      ('__torch__.m.forward',
  //          (('instructions',
  //              (('STOREN', 1, 2),
  //                  ('DROPR', 1, 0),
  //                  ('MOVE', 2, 0),
  //                  ('OP', 0, 0),
  //                  ('RET', 0, 0))),
  //              ('operators', (('aten::linalg_pinv', '', 1),)),
  //              ('constants', (False, 1e-15)), # default constants are not used
  //              ('types', ()),
  //              ('register_size', 2)),
  //          (('arguments',
  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value', None)),
  //                  (('name', 'input'), ('type', 'Tensor'), ('default_value', None)))),
  //              ('returns',
  //                  ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))

  //  bytecode with 2 specified arguments:
  //  (6,
  //      ('__torch__.m.forward',
  //          (('instructions',
  //              (('STOREN', 1, 2),
  //                  ('DROPR', 1, 0),
  //                  ('MOVE', 2, 0),
  //                  ('LOADC', 1, 0), # added LOADC for specified argument
  //                  ('OP', 0, 0),
  //                  ('RET', 0, 0))),
  //              ('operators', (('aten::linalg_pinv', '', 2),)),
  //              ('constants', (False, 1e-05)), # updated constant table
  //              ('types', ()),
  //              ('register_size', 2)),
  //          (('arguments',
  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value', None)),
  //                  (('name', 'input'), ('type', 'Tensor'), ('default_value', None)))),
  //              ('returns',
  //                  ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))

  //  bytecode with 3 specified arguments:
  //  (6,
  //      ('__torch__.m.forward',
  //          (('instructions',
  //              (('STOREN', 1, 2),
  //                  ('DROPR', 1, 0),
  //                  ('MOVE', 2, 0),
  //                  ('LOADC', 1, 0),
  //                  ('LOADC', 0, 0),
  //                  ('OP', 0, 0),
  //                  ('RET', 0, 0))),
  //              ('operators', (('aten::linalg_pinv', '', 3),)),
  //              ('constants', (True, 1e-05)),
  //              ('types', ()),
  //              ('register_size', 2)),
  //          (('arguments',
  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value', None)),
  //                  (('name', 'input'), ('type', 'Tensor'), ('default_value', None)))),
  //              ('returns',
  //                  ((('name', ''), ('type', 'Tensor'), ('default_value', None)),)))))
}

TEST(FlatbufferTest, DefaultArgsTensorinvSpecifyDefault) {
  // The second argument is specified, but the value is the same as the default
  // value. It's treated as "not specified" since the value can be fetched from
  // the schema.
  Module m("m");
  m.define(R"(
    def forward(self, input):
      return torch.linalg_tensorinv(input, 2)
  )");
  torch::jit::MobileCode code(m.get_method("forward").graph(), "forward");
  auto arg_nums = code.op_to_num_specified_args();
  ASSERT_EQ(arg_nums.size(), 1);
  ASSERT_EQ(arg_nums["aten::linalg_tensorinv"], 1);
  std::vector<torch::jit::IValue> inputs;
  const int N = 4;
  auto input = torch::rand({N, N, N, N});
  inputs.emplace_back(input);
  testLiteModuleCompareResultTensors(m, inputs);
}

static void testDefaultArgsPinvWithOutArg(int num_args) {
  Module m("m");
  if (num_args == 1) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, out=input)
    )");
  } else if (num_args == 2) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, out=input)
    )");
  } else if (num_args == 3) {
    m.define(R"(
      def forward(self, input):
        return torch.linalg_pinv(input, 1e-5, True, out=input)
    )");
  }

  const int N = 28;
  auto input = torch::range(1, N * N, 1);
  input[0] = 10000; // a more stable matrix
  input = input.view({N, N});
  auto ref = m.run_method("forward", input);
  TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
  TORCH_CHECK(input.equal(ref.toTensor()));
}

TEST(FlatbufferTest, DefaultArgsPinvWithOutArg) {
  // Test with different number of specified arguments + out arg.
  // Arguments not specified take default value.
  for (int num_args = 1; num_args <= 3; ++num_args) {
    testDefaultArgsPinvWithOutArg(num_args);
  }
}

TEST(FlatbufferTest, DefaultArgsWithOutArg) {
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
      torch.add(x, h, out=x)
  )");

  std::vector<IValue> inputs;
  auto input_x = 2 * torch::ones({});
  auto input_h = torch::ones({});
  auto ref = m.run_method("forward", input_x, input_h);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bc.run_method("forward", input_x, input_h);
  AT_ASSERT(input_x.equal(4 * torch::ones({})));

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  auto input_x2 = 2 * torch::ones({});
  auto input_h2 = torch::ones({});
  m.run_method("forward", input_x2, input_h2);
  bc2.run_method("forward", input_x2, input_h2);
  AT_ASSERT(input_x2.equal(4 * torch::ones({})));
}

#endif // !defined(FB_XPLAT_BUILD)

namespace {
static auto reg =
    torch::class_<TorchBindFlatbufferTestStruct>(
        "_TorchScriptTesting",
        "_FlatbufferTest")
        .def(torch::init<>())
        .def("get", &TorchBindFlatbufferTestStruct::get)
        .def_pickle(
            // __getstate__
            [](const c10::intrusive_ptr<TorchBindFlatbufferTestStruct>& self)
                -> int64_t { return 0; },
            // __setstate__
            [](int64_t state) {
              return c10::make_intrusive<TorchBindFlatbufferTestStruct>();
            });

} // namespace

TEST(FlatbufferTest, OperatorCacheDifferentiatesDefaultArgs) {
  // Create 3 methods:
  //
  // 1. forward() returns a tensor with dtype=torch.int64 (4)
  // 2. forward2() returns a tensor with dtype=torch.float32 (6)
  // 3. forward3() returns a tensor with dtype=torch.float32 but
  //    the dtype is inferred by the input tensor's dtype
  //
  // If caching works correctly, then the result from the full-jit
  // module and the lite module will be the same. Otherwise, it
  // will be different if we don't correctly ignore the cache
  // entry for an operator that has a different number of
  // arguments.
  Module m("m");
  m.define(R"(
    def forward(self):
      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4)
      return ret1.fill_(25)
  )");
  m.define(R"(
    def forward2(self):
      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6)
      return ret1.fill_(32.0)
  )");
  m.define(R"(
    def forward3(self):
      ret1 = torch.new_empty(torch.zeros(10), [10])
      return ret1.fill_(12.0)
  )");

  std::vector<torch::jit::IValue> inputs;
  testLiteModuleCompareResultTensors(m, inputs, "forward");
  testLiteModuleCompareResultTensors(m, inputs, "forward2");
  testLiteModuleCompareResultTensors(m, inputs, "forward3");
}

TEST(FlatbufferTest, OperatorSize1) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  const auto& func = bc.get_method("forward").function();
  ASSERT_EQ(
      func.get_code().operator_input_sizes_.size(),
      func.get_code().operators_.size());

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  const auto& func2 = bc2.get_method("forward").function();
  ASSERT_EQ(
      func2.get_code().operator_input_sizes_.size(),
      func2.get_code().operators_.size());
}

TEST(FlatbufferTest, BoolAndDoubleList) {
  Module m("m");
  c10::List<bool> boollist;
  boollist.push_back(false);
  IValue boollist_ival = boollist;
  IValue doublelist = std::vector<double>{2.0};
  m.register_attribute("bool_list", boollist_ival.type(), boollist_ival);
  m.register_attribute("double_list", doublelist.type(), doublelist);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());

  // If the variables read back have the wrong type, the conversion will
  // raise an exception.
  auto boolval = bc2.attr("bool_list", {}).toBoolList().get(0);
  auto doubleval = bc2.attr("double_list", {}).toDoubleList().get(0);

  ASSERT_EQ(boolval, false);
  ASSERT_EQ(doubleval, 2.0);
}

TEST(FlatbufferTest, OperatorTest2) { // NOLINT (use =delete in gtest)
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
      def test_func(self, x, b : int = 4):
        return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
      def add_with_default_arg(self, x, b : int = 4):
        return self.foo + x + b
      def test_func(self, x):
        return self.add_with_default_arg(x)  # invoke method w/ default arg
      )",
      // simple method call
      R"(
      def test_func(self, x):
        b = 4
        return self.foo + x + b
      )",
  };
  for (const auto& test_program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(test_program);

    CompilationOptions options;
    mobile::Module bc = jitModuleToMobile(m, options);
    const auto& func = bc.get_method("test_func").function();
    ASSERT_EQ(
        func.get_code().operator_input_sizes_.size(),
        func.get_code().operators_.size());

    auto buff = save_mobile_module_to_bytes(bc);
    mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
    const auto& func2 = bc2.get_method("test_func").function();
    ASSERT_EQ(
        func2.get_code().operator_input_sizes_.size(),
        func2.get_code().operators_.size());
  }
}

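// Helper for the TestSourceFlatbuffer tests below: loads a full JIT module
// (not just mobile bytecode) from flatbuffer bytes via
// parse_and_initialize_jit_module.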
Module jitModuleFromBuffer(void* data, size_t size) {
  // Make a copy of the data so we can use the existing API, which takes
  // ownership. The `data` param might point into the middle of a buffer, so we
  // can't safely take ownership of it directly.
  // @nolint CLANGTIDY cppcoreguidelines-no-malloc
  std::shared_ptr<char> copy(static_cast<char*>(malloc(size)), free);
  memcpy(copy.get(), data, size);

  ExtraFilesMap extra_files;
  return parse_and_initialize_jit_module(std::move(copy), size, extra_files);
}

TEST(TestSourceFlatbuffer, UpsampleNearest2d) {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  auto ref = m.forward(inputs);

  std::stringstream ss;
  m._save_for_mobile(ss, {}, false, /*use_flatbuffer=*/true);
  auto mm = _load_for_mobile(ss);
  auto m2 = load(ss);

  auto res = m2.forward(inputs);
  auto resm = mm.forward(inputs);

  auto resd = res.toTensor();
  auto refd = ref.toTensor();
  auto resmd = resm.toTensor();
  ASSERT_TRUE(resd.equal(refd));
  ASSERT_TRUE(resmd.equal(refd));
}

TEST(TestSourceFlatbuffer, CheckAttrAccess) {
  Module m("m");
  m.register_attribute("mobile_optimized", BoolType::get(), true);
  auto data = save_jit_module_to_bytes(m);
  Module m2 = jitModuleFromBuffer(data->data(), data->size());
  bool mobile_optimized = m2.attr("mobile_optimized", false).toBool();
  AT_ASSERT(mobile_optimized);
  mobile::Module m3 = parse_mobile_module(data->data(), data->size());
  mobile_optimized = m3.attr("mobile_optimized", false).toBool();
  AT_ASSERT(mobile_optimized);
}

TEST(TestSourceFlatbuffer, MethodInvocation) { // NOLINT (use =delete in gtest)
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
      def test_func(self, x, b : int = 4):
        return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
      def add_with_default_arg(self, x, b : int = 4):
        return self.foo + x + b
      def test_func(self, x):
        return self.add_with_default_arg(x)  # invoke method w/ default arg
      )",
      // simple method call
      R"(
      def test_func(self, x):
        b = 4
        return self.foo + x + b
      )",
  };
  for (const auto& test_program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(test_program);

    const int fortyTwo = 42; // (keep linter happy)
    auto minput = fortyTwo * torch::ones({});
    auto ref = m.run_method("test_func", minput);

    auto data = save_jit_module_to_bytes(m);
    Module m2 = jitModuleFromBuffer(data->data(), data->size());
    const auto& test_func = m2.get_method("test_func");
    IValue res;
    for (int i = 0; i < 3; ++i) {
      res = test_func({minput});
    }
    auto resd = res.toTensor().item<float>();
    auto refd = ref.toTensor().item<float>();
    AT_ASSERT(resd == refd);

    mobile::Module m3 = parse_mobile_module(data->data(), data->size());
    const auto& test_func3 = m3.get_method("test_func");
    for (int i = 0; i < 3; ++i) {
      res = test_func3({minput});
    }
    resd = res.toTensor().item<float>();
    refd = ref.toTensor().item<float>();
    AT_ASSERT(resd == refd);
  }
}

#if !defined(FB_XPLAT_BUILD)
// The following tests run in fbcode only
TEST(FlatbufferUpgraderTest, DivTensorV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append("upgrader_models/test_versioned_div_tensor_v2.ptl.ff");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('LOAD', 2, 0),
       ('LOAD', 3, 0),
       ('OP', 0, 0),
       ('LOAD', 2, 0),
       ('LOAD', 3, 0),
       ('OP', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 2, 0),
       ('TUPLE_CONSTRUCT', 3, 0),
       ('RET', 0, 0))),
     ('operators',
      (('aten::div', 'Tensor'),
       ('aten::div', 'Tensor'),
       ('aten::div', 'Tensor'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);
  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // 3 operators will use upgrader
  ASSERT_EQ(number_of_call_instruction, 3);

  std::vector<IValue> inputs = {
      IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))};
  auto actual_output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output_list = actual_output.toTuple()->elements();
  ASSERT_TRUE(actual_output_list[0].toTensor().equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivTensorOutV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_tensor_out_v2.ptl.ff");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 4),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('MOVE', 4, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div', 'out'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 4))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // One operator will use upgrader
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{
      IValue(6 * torch::ones({1})),
      IValue(3 * torch::ones({1})),
      IValue(torch::empty({1}))};
  m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = inputs[2].toTensor();
  // The out argument will be overwritten with the output
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivTensorInplaceV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_tensor_inplace_v2.ptl.ff");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div_', 'Tensor'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // The single aten::div_.Tensor invocation should go through an upgrader.
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{
      IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))};
  m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = inputs[0].toTensor();
  // The in-place div_ writes the result into the first input.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivScalarFloatV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_float_v2.ptl.ff");
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */

  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // The single aten::div.Scalar invocation should go through an upgrader.
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // The upgraded div performs true division: 6 / 3.0 == 2.0.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivScalarReciprocalFloatV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl.ff");
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('OP', 0, 0),
       ('MOVE', 3, 0),
       ('OP', 1, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // No operator here needs an upgrader.
  ASSERT_EQ(number_of_call_instruction, 0);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 0.5 * torch::ones({1});
  auto actual_output = output.toTensor();
  // reciprocal(6) * 3.0 == 0.5.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

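// Note: in the two "reciprocal" models, `scalar / tensor` was presumably
// decomposed at export time into `tensor.reciprocal() * scalar`. That is why
// the bytecode lists aten::reciprocal and aten::mul instead of a versioned
// aten::div, and why no upgrader CALL is emitted.
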
TEST(FlatbufferUpgraderTest, DivScalarReciprocalIntV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl.ff");
  /*
  (('__torch__.MyModuleInt.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('OP', 0, 0),
       ('MOVE', 3, 0),
       ('OP', 1, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // No operator here needs an upgrader.
  ASSERT_EQ(number_of_call_instruction, 0);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 0.5 * torch::ones({1});
  auto actual_output = output.toTensor();

  // reciprocal(6) * 3.0 == 0.5.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivScalarScalarV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_scalar_v2.ptl.ff");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 5),
       ('DROPR', 1, 0),
       ('LOAD', 2, 0),
       ('LOAD', 3, 0),
       ('OP', 0, 0),
       ('MOVE', 2, 0),
       ('LOAD', 4, 0),
       ('OP', 1, 0),
       ('LOAD', 3, 0),
       ('MOVE', 4, 0),
       ('OP', 2, 0),
       ('MOVE', 3, 0),
       ('MOVE', 5, 0),
       ('OP', 3, 0),
       ('TUPLE_CONSTRUCT', 4, 0),
       ('RET', 0, 0))),
     ('operators',
      (('aten::div', ''),
       ('aten::div', 'float'),
       ('aten::div', ''),
       ('aten::div', 'int'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 5))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);
  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // None of these scalar-scalar div overloads was versioned, so no upgrader
  // is used.
  ASSERT_EQ(number_of_call_instruction, 0);

  std::vector<IValue> inputs{IValue(20.0), IValue(10), IValue(2.0), IValue(5)};
  auto output = m_module.forward(inputs);
  auto output_list = output.toTupleRef().elements();
  auto expect_output = std::vector<IValue>(
      {IValue(2.0), IValue(10.0), IValue(5.0), IValue(2.0)});
  for (size_t i = 0; i < expect_output.size(); i++) {
    ASSERT_EQ(output_list[i], expect_output[i]);
  }
}

TEST(FlatbufferUpgraderTest, DivScalarIntV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_int_v2.ptl.ff");
  /*
  (('__torch__.MyModuleInt.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // The single aten::div.Scalar invocation should go through an upgrader.
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // The upgraded div performs true division: 6 / 3 == 2.0.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivScalarInplaceFloatV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl.ff");
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div_', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */

  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // The single aten::div_.Scalar invocation should go through an upgrader.
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // In-place div_ returns the mutated input tensor.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

TEST(FlatbufferUpgraderTest, DivScalarInplaceIntV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl.ff");
  /*
  (('__torch__.MyModuleInt.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div_', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */

  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  auto instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (auto& instruction : instruction_list) {
    number_of_call_instruction += (instruction.op == OpCode::CALL);
  }
  // The single aten::div_.Scalar invocation should go through an upgrader.
  ASSERT_EQ(number_of_call_instruction, 1);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // In-place div_ returns the mutated input tensor.
  ASSERT_TRUE(actual_output.equal(expect_output));
}

#endif // !defined(FB_XPLAT_BUILD)

//
// Tests that need access to internal flatbuffers types/functions.
// Do not add any other tests after this section.
//

} // namespace jit
} // namespace torch
namespace torch {
namespace jit {

/**
 * An Allocator that can only deallocate (using delete[]), counting
 * the number of times that it has been asked to deallocate.
 */
class TestAllocator : public flatbuffers::Allocator {
 public:
  /**
   * *deallocate_call_count will be incremented whenever deallocate() is called.
   */
  explicit TestAllocator(int* deallocate_call_count)
      : deallocate_call_count_(deallocate_call_count) {}

  void deallocate(uint8_t* p, size_t /*size*/) override {
    *deallocate_call_count_ += 1;
    delete[] p;
  }

  uint8_t* allocate(size_t) override {
    TORCH_CHECK(false, "allocate() should not be called");
  }
  uint8_t* reallocate_downward(uint8_t*, size_t, size_t, size_t, size_t)
      override {
    TORCH_CHECK(false, "reallocate_downward() should not be called");
  }

 private:
  int* deallocate_call_count_;
};

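// Typical use, mirroring the smoke test below: hand the allocator to a
// flatbuffers::DetachedBuffer that does not own it. When the last buffer
// owning the data is destroyed, it deallocates through the allocator, which
// bumps the counter and delete[]s the data. A minimal sketch:
//
//   int count = 0;
//   TestAllocator alloc(&count);
//   uint8_t* data = new uint8_t[4];
//   {
//     flatbuffers::DetachedBuffer buf(
//         &alloc, /*own_allocator=*/false, data, /*reserved=*/4,
//         /*cur=*/data, /*sz=*/4);
//   }  // ~DetachedBuffer() calls alloc.deallocate(data, 4); count is now 1.
//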
/// Provides access to DetachedBuffer::destroy().
struct DetachedBufferTestingFriend {
  /// Returns a UniqueDetachedBuffer that wraps the provided DetachedBuffer.
  /// A copy of similar code in flatbuffer_serializer.cpp.
  static DetachedBuffer::UniqueDetachedBuffer make_unique_detached_buffer(
      DetachedBuffer* buf) {
    return DetachedBuffer::UniqueDetachedBuffer(buf, DetachedBuffer::destroy);
  }
};

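// Note: DetachedBuffer::destroy() is not part of the public API; the struct
// above presumably relies on a friend declaration inside
// torch::jit::DetachedBuffer to reach it.
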
TEST(FlatbufferTest, DetachedBufferSmoke) {
  // Use a custom Allocator to watch the lifecycle of a
  // flatbuffers::DetachedBuffer.
  int deallocate_call_count = 0;
  TestAllocator alloc(&deallocate_call_count);

  // Data for the buffer. TestAllocator will free it with `delete[]`.
  constexpr size_t data_size = 4;
  uint8_t* data = new uint8_t[data_size];

  // An internal buffer on the stack that owns the data.
  flatbuffers::DetachedBuffer fb_buf_local(
      &alloc, /*own_allocator=*/false, data, data_size, data, data_size);
  EXPECT_EQ(fb_buf_local.data(), data);
  EXPECT_EQ(fb_buf_local.size(), data_size);

  // Mimic the code inside save_mobile_module_to_bytes by transferring
  // ownership to a heap object.
  auto fb_buf_ptr = new flatbuffers::DetachedBuffer(std::move(fb_buf_local));
  // The data should not have been deleted yet.
  EXPECT_EQ(deallocate_call_count, 0);
  // The new object points to the data.
  EXPECT_EQ(fb_buf_ptr->data(), data);
  EXPECT_EQ(fb_buf_ptr->size(), data_size);
  // The old object points to nothing.
  // @lint-ignore CLANGTIDY bugprone-use-after-move
  EXPECT_EQ(fb_buf_local.data(), nullptr);
  // @lint-ignore CLANGTIDY bugprone-use-after-move
  EXPECT_EQ(fb_buf_local.size(), 0);

  // The top-level torch::jit::DetachedBuffer.
  auto wrapped_buf =
      new DetachedBuffer(fb_buf_ptr->data(), fb_buf_ptr->size(), fb_buf_ptr);
  EXPECT_EQ(wrapped_buf->data(), data);
  EXPECT_EQ(wrapped_buf->size(), data_size);

  // The unique_ptr that owns the torch::jit::DetachedBuffer and its contents.
  {
    DetachedBuffer::UniqueDetachedBuffer unique_buf =
        DetachedBufferTestingFriend::make_unique_detached_buffer(wrapped_buf);
    EXPECT_EQ(unique_buf->data(), data);
    EXPECT_EQ(unique_buf->size(), data_size);

    // The data should not have been deleted yet.
    EXPECT_EQ(deallocate_call_count, 0);
  }

  // Now that the unique_ptr is out of scope, the data should have been
  // deleted.
  EXPECT_EQ(deallocate_call_count, 1);
}

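// Ownership chain exercised above:
//   UniqueDetachedBuffer
//     -> torch::jit::DetachedBuffer (wrapped_buf)
//       -> flatbuffers::DetachedBuffer (fb_buf_ptr)
//         -> TestAllocator-managed data
// Destroying the head of the chain releases every level exactly once, which
// is what the single deallocate() call verifies.
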
TEST(FlatbufferTest, DetachedBufferNullOwner) {
  // A torch::jit::DetachedBuffer with a null internal owner.
  std::vector<uint8_t> data(4);
  auto wrapped_buf = new DetachedBuffer(data.data(), data.size());

  // A unique_ptr that owns the torch::jit::DetachedBuffer and its contents.
  {
    DetachedBuffer::UniqueDetachedBuffer unique_buf =
        DetachedBufferTestingFriend::make_unique_detached_buffer(wrapped_buf);
    EXPECT_EQ(unique_buf->data(), data.data());
    EXPECT_EQ(unique_buf->size(), data.size());
  }

  // The DetachedBuffer should have been destroyed when the
  // UniqueDetachedBuffer went out of scope. If we didn't crash or get any
  // ASAN warnings, we should be good.
}

//
// Do not add tests here unless they require flatbuffers types. See comment at
// the beginning of this section.
//

} // namespace jit
} // namespace torch
1958