1 #include <test/cpp/jit/test_utils.h>
2
3 #include <gtest/gtest.h>
4
5 #include <c10/core/TensorOptions.h>
6 #include <torch/csrc/autograd/generated/variable_factories.h>
7 #include <torch/csrc/jit/api/module.h>
8 #include <torch/csrc/jit/frontend/resolver.h>
9 #include <torch/csrc/jit/mobile/compatibility/backport.h>
10 #include <torch/csrc/jit/mobile/compatibility/backport_manager.h>
11 #include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
12 #include <torch/csrc/jit/mobile/compatibility/runtime_compatibility.h>
13 #include <torch/csrc/jit/mobile/flatbuffer_loader.h>
14 #include <torch/csrc/jit/mobile/import.h>
15 #include <torch/csrc/jit/mobile/interpreter.h>
16 #include <torch/csrc/jit/mobile/module.h>
17 #include <torch/csrc/jit/mobile/parse_bytecode.h>
18 #include <torch/csrc/jit/mobile/parse_operators.h>
19 #include <torch/csrc/jit/serialization/export.h>
20 #include <torch/csrc/jit/serialization/export_bytecode.h>
21 #include <torch/csrc/jit/serialization/flatbuffer_serializer.h>
22 #include <torch/csrc/jit/serialization/flatbuffer_serializer_jit.h>
23 #include <torch/csrc/jit/serialization/import.h>
24 #include <torch/custom_class.h>
25 #include <torch/torch.h>
26
27 #include <caffe2/serialize/versions.h>
28 #include <torch/csrc/jit/serialization/import_export_functions.h>
29 #include <unordered_set>
30
31 #if defined(FB_XPLAT_BUILD) || defined(FBCODE_CAFFE2)
32 #include <torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h> // NOLINT
33 namespace flatbuffers = flatbuffers_fbsource;
34 #define FLATBUFFERS_MAX_ALIGNMENT FLATBUFFERS_FBSOURCE_MAX_ALIGNMENT
35 #else
36 #include <torch/csrc/jit/serialization/mobile_bytecode_generated.h> // NOLINT
37 #endif
38 // Tests go in torch::jit
39 namespace torch {
40 namespace jit {
41
42 namespace {
parse_mobile_module(void * data,size_t size,bool should_copy_tensor_memory=false)43 mobile::Module parse_mobile_module(
44 void* data,
45 size_t size,
46 bool should_copy_tensor_memory = false) {
47 return parse_and_initialize_mobile_module(
48 static_cast<char*>(data),
49 size,
50 /*device=*/std::nullopt,
51 /*extra_files=*/nullptr,
52 should_copy_tensor_memory);
53 }
54 } // namespace
55
TEST(FlatbufferTest, LoadMalformedModule) {
  // Manually create some data with Flatbuffer header.
  // NOTE(review): operator<< with a `const char*` stops at the first NUL byte,
  // so the "\x00..." tails of both literals are never written to the stream.
  // The payload actually tested is "PK\x03\x04PTMF" followed by
  // "*}NV\xb3\xfa\xdf".
  std::stringstream bad_data;
  bad_data << "PK\x03\x04PTMF\x00\x00"
           << "*}NV\xb3\xfa\xdf\x00pa";

  // The JIT loader must reject the bytes (guard in
  // parse_and_initialize_mobile_module_for_jit).
  ASSERT_THROWS_WITH_MESSAGE(
      torch::jit::load(bad_data), "Malformed Flatbuffer module");

  // The mobile parser must reject them too (guard in
  // parse_and_initialize_mobile_module). str() returns the whole buffer
  // regardless of how much the load above consumed.
  std::string payload = bad_data.str();
  ASSERT_THROWS_WITH_MESSAGE(
      parse_mobile_module(payload.data(), payload.size()),
      "Malformed Flatbuffer module");
}
72
TEST(FlatbufferTest, UpsampleNearest2d) {
  // A single upsample_nearest2d call whose scale factors come from `scale`.
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));

  // Reference output from the full JIT module.
  auto ref = m.forward(inputs);
  auto refd = ref.toTensor();

  // The lowered mobile module must reproduce the reference exactly.
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue mobile_out = bc.forward(inputs);
  ASSERT_TRUE(mobile_out.toTensor().equal(refd));

  // So must the module parsed back from flatbuffer bytes.
  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  IValue parsed_out = bc2.forward(inputs);
  ASSERT_TRUE(parsed_out.toTensor().equal(refd));
}
100
TEST(FlatbufferTest, UpsampleNearest2dWithCopyTensorMemory) {
  // Same as UpsampleNearest2d, but the flatbuffer parse copies tensor memory.
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));

  // Reference output from the full JIT module.
  auto ref = m.forward(inputs);
  auto refd = ref.toTensor();

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue mobile_out = bc.forward(inputs);
  ASSERT_TRUE(mobile_out.toTensor().equal(refd));

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(
      buff->data(), buff->size(), /*should_copy_tensor_memory=*/true);
  IValue parsed_out = bc2.forward(inputs);
  ASSERT_TRUE(parsed_out.toTensor().equal(refd));
}
129
TEST(FlatbufferTest, CheckAttrAccess) {
  // The bool attribute "mobile_optimized" must reflect the JIT module's value
  // at lowering time, and must survive a flatbuffer round-trip.
  Module m("m");
  m.register_attribute("mobile_optimized", BoolType::get(), true);

  // Reads "mobile_optimized" from a mobile module, falling back to false.
  auto read_flag = [](mobile::Module& mod) {
    return mod.attr("mobile_optimized", false).toBool();
  };

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  AT_ASSERT(read_flag(bc));

  // Flip the attribute and lower again.
  m.setattr("mobile_optimized", false);
  bc = jitModuleToMobile(m, options);
  AT_ASSERT(!read_flag(bc));

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  AT_ASSERT(!read_flag(bc2));
}
150
TEST(FlatbufferTest, MethodInvocation) { // NOLINT (use =delete in gtest)
  // Every program defines `test_func`, reached through a different call shape.
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
      def test_func(self, x, b : int = 4):
        return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
      def add_with_default_arg(self, x, b : int = 4):
        return self.foo + x + b
      def test_func(self, x):
        return self.add_with_default_arg(x)  # invoke method w/ default arg
      )",
      // simple method call
      R"(
      def test_func(self, x):
        b = 4
        return self.foo + x + b
      )",
  };

  // Invokes `method` three times with `arg` and returns the last result as a
  // float.
  auto call_three_times = [](const auto& method, const at::Tensor& arg) {
    IValue last;
    for (int iter = 0; iter < 3; ++iter) {
      last = method({arg});
    }
    return last.toTensor().item<float>();
  };

  for (const auto& program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(program);

    const int fortyTwo = 42; // (keep linter happy)
    auto minput = fortyTwo * torch::ones({});
    auto ref = m.run_method("test_func", minput);
    const float refd = ref.toTensor().item<float>();

    CompilationOptions options;
    mobile::Module bc = jitModuleToMobile(m, options);
    const float mobile_val = call_three_times(bc.get_method("test_func"), minput);
    AT_ASSERT(mobile_val == refd);

    auto buff = save_mobile_module_to_bytes(bc);
    mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
    const float parsed_val =
        call_three_times(bc2.get_method("test_func"), minput);
    AT_ASSERT(parsed_val == refd);
  }
}
204
205 #if !defined(FB_XPLAT_BUILD)
TEST(FlatbufferTest, FlatbufferBackPortTest) {
  // A module saved in flatbuffer form must be backportable to bytecode
  // version 5.
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::stringstream flatbuffer_model;
  m._save_for_mobile(flatbuffer_model, {}, false, /*use_flatbuffer=*/true);

  std::stringstream backported_model;
  const bool backported =
      _backport_for_mobile(flatbuffer_model, backported_model, 5);
  ASSERT_TRUE(backported);
}
219 #endif // !defined(FB_XPLAT_BUILD)
220
TEST(FlatbufferTest, ExtraFiles) {
  // Extra files saved alongside a flatbuffer module must be readable on load,
  // both when keys are requested explicitly and via
  // PARSE_ALL_EXTRA_FILE_MAPS.
  const auto script = R"JIT(
    def forward(self):
        x = torch.rand(5, 5)
        x = x.mm(x)
        return x
  )JIT";

  auto module =
      std::make_shared<Module>("Module", std::make_shared<CompilationUnit>());
  module->define(script);

  std::unordered_map<std::string, std::string> extra_files;
  extra_files["metadata.json"] = "abc";
  extra_files["mobile_info.json"] = "{\"key\": 23}";

  std::stringstream ss;
  module->_save_for_mobile(ss, extra_files, true, /*use_flatbuffer=*/true);

  // Only "metadata.json" is requested explicitly; the assertions below also
  // expect "mobile_info.json" to be populated.
  std::unordered_map<std::string, std::string> loaded_extra_files;
  loaded_extra_files["metadata.json"] = "";
  auto mobile_module = _load_for_mobile(ss, std::nullopt, loaded_extra_files);

  ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
  ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");

  // load it twice using the same stream
  auto mobile_module2 = _load_for_mobile(ss, std::nullopt, loaded_extra_files);

  ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
  ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");

  // Test if flatbuffer does not require any explicit key entries mapping in the
  // extra file map.
  std::unordered_map<std::string, std::string>
      loaded_extra_files_without_explicit_entries;
  auto mobile_module3 = _load_for_mobile(
      ss,
      std::nullopt,
      loaded_extra_files_without_explicit_entries,
      MobileModuleLoadOptions::PARSE_ALL_EXTRA_FILE_MAPS);

  ASSERT_EQ(
      loaded_extra_files_without_explicit_entries["metadata.json"], "abc");
  ASSERT_EQ(
      loaded_extra_files_without_explicit_entries["mobile_info.json"],
      "{\"key\": 23}");
}
269
TEST(FlatbufferTest, Conv) {
  // Skipped when running under TSAN.
  const char* tsan_flag = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (tsan_flag != nullptr && strcmp(tsan_flag, "1") == 0) {
    return;
  }

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  // Runs forward three times on `mod` and returns the last output tensor.
  auto run_forward = [&inputs](mobile::Module& mod) {
    IValue last;
    for (int iter = 0; iter < 3; ++iter) {
      last = mod.get_method("forward")(inputs);
    }
    return last.toTensor();
  };
  // Compares rank and the first element (read as int) with the reference.
  auto check_output = [&outputref](const at::Tensor& output) {
    AT_ASSERT(outputref.dim() == output.dim());
    AT_ASSERT(
        outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
  };

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  check_output(run_forward(bc));

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  check_output(run_forward(bc2));
}
311
TEST(FlatbufferTest, ConvWithCopyTensorMemory) {
  // Skipped when running under TSAN.
  const char* tsan_flag = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (tsan_flag != nullptr && strcmp(tsan_flag, "1") == 0) {
    return;
  }

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  // Runs forward three times on `mod` and returns the last output tensor.
  auto run_forward = [&inputs](mobile::Module& mod) {
    IValue last;
    for (int iter = 0; iter < 3; ++iter) {
      last = mod.get_method("forward")(inputs);
    }
    return last.toTensor();
  };
  // Compares rank and the first element (read as int) with the reference.
  auto check_output = [&outputref](const at::Tensor& output) {
    AT_ASSERT(outputref.dim() == output.dim());
    AT_ASSERT(
        outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
  };

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  check_output(run_forward(bc));

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(
      buff->data(), buff->size(), /*should_copy_tensor_memory=*/true);
  check_output(run_forward(bc2));
}
354
TEST(FlatbufferTest, Inline) {
  // foo3 -> foo2 -> foo1 call chain: foo3(1) == 1 + 1 + 2 + 3 == 7.
  Module m("m");
  m.define(R"JIT(
  def foo1(self, x):
      return x + 1

  def foo2(self, x):
      return self.foo1(x) + 2

  def foo3(self, x):
      return self.foo2(x) + 3
  )JIT");

  // Calls foo3 on a scalar one and returns the result as a float.
  auto foo3_of_one = [](mobile::Module& mod) {
    std::vector<torch::jit::IValue> args({torch::ones({})});
    return mod.get_method("foo3")(args).toTensor().item<float>();
  };

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  AT_ASSERT(foo3_of_one(bc) == 7.0);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  AT_ASSERT(foo3_of_one(bc2) == 7.0);
}
379
TEST(FlatbufferTest, InlineWithCopyTensorMemory) {
  // foo3 -> foo2 -> foo1 call chain: foo3(1) == 7. The flatbuffer parse
  // copies tensor memory.
  Module m("m");
  m.define(R"JIT(
  def foo1(self, x):
      return x + 1

  def foo2(self, x):
      return self.foo1(x) + 2

  def foo3(self, x):
      return self.foo2(x) + 3
  )JIT");

  // Calls foo3 on a scalar one and returns the result as a float.
  auto foo3_of_one = [](mobile::Module& mod) {
    std::vector<torch::jit::IValue> args({torch::ones({})});
    return mod.get_method("foo3")(args).toTensor().item<float>();
  };

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  AT_ASSERT(foo3_of_one(bc) == 7.0);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(
      buff->data(), buff->size(), /*should_copy_tensor_memory=*/true);
  AT_ASSERT(foo3_of_one(bc2) == 7.0);
}
404
TEST(FlatbufferTest, Tuple) {
  // forward returns the tuple (1, 2, x + 3) built by foo. Element [1] (the
  // constant 2) must survive lowering and a flatbuffer round-trip.
  Module m("m");
  m.define(R"JIT(
  def foo(self, x):
      return (1, 2, x + 3)

  def forward(self, x):
      tuple = self.foo(x)
      return tuple
  )JIT");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.get_method("forward")(inputs);
  AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  output = bc2.get_method("forward")(inputs);
  // Borrow the tuple with toTupleRef(), as in the check above, instead of
  // materializing the refcounted handle that toTuple() returns.
  AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);
}
426
TEST(FlatbufferTest, Dict) {
  // forward returns {"result": x + 1}; with x == 1 the entry must hold 2.
  Module m("m");
  m.define(R"JIT(
  def foo(self, x):
      return {"result": x + 1}

  def forward(self, x):
      d = self.foo(x)
      return d
  )JIT");

  // Extracts the "result" entry of a returned dict as an integer.
  auto result_as_int = [](const IValue& output) {
    return output.toGenericDict().at("result").toTensor().item().toInt();
  };

  std::vector<torch::jit::IValue> inputs({torch::ones({})});

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  AT_ASSERT(result_as_int(bc.get_method("forward")(inputs)) == 2);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  AT_ASSERT(result_as_int(bc2.get_method("forward")(inputs)) == 2);
}
448
TEST(FlatbufferTest, Prim) {
  // int(tensor) conversion must give the same result in JIT, mobile and after
  // a flatbuffer round-trip.
  Module m("m");
  m.define(R"JIT(
        def forward(self, x):
            return int(x)
  )JIT");

  std::vector<IValue> inputs;
  auto minput = 3.5 * torch::ones({});
  inputs.emplace_back(minput);
  const auto refi = m.run_method("forward", minput).toInt();

  // Runs forward three times, each on a fresh copy of `inputs`, and returns
  // the last result as an int.
  auto run_forward = [&inputs](mobile::Module& mod) {
    IValue last;
    for (int iter = 0; iter < 3; ++iter) {
      // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
      auto bcinputs = inputs;
      last = mod.get_method("forward")(bcinputs);
    }
    return last.toInt();
  };

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  AT_ASSERT(run_forward(bc) == refi);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  AT_ASSERT(run_forward(bc2) == refi);
}
484
TEST(FlatbufferTest, PrimScalar) {
  // int(x.item()) must give the same result in JIT, mobile and after a
  // flatbuffer round-trip.
  Module m("m");
  m.define(R"JIT(
        def forward(self, x):
            return int(x.item())
  )JIT");

  std::vector<IValue> inputs;
  auto minput = 3.5 * torch::ones({});
  inputs.emplace_back(minput);
  const auto refi = m.run_method("forward", minput).toInt();

  // Runs forward three times, each on a fresh copy of `inputs`, and returns
  // the last result as an int.
  auto run_forward = [&inputs](mobile::Module& mod) {
    IValue last;
    for (int iter = 0; iter < 3; ++iter) {
      // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
      auto bcinputs = inputs;
      last = mod.get_method("forward")(bcinputs);
    }
    return last.toInt();
  };

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  AT_ASSERT(run_forward(bc) == refi);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  AT_ASSERT(run_forward(bc2) == refi);
}
520
TEST(FlatbufferTest, WrongMethodName) {
  // Only `add` is defined, so calling `forward` must throw "is not defined",
  // both before and after the flatbuffer round-trip.
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);

  std::vector<IValue> inputs{5 * torch::ones({})};
  ASSERT_THROWS_WITH_MESSAGE(
      bc.get_method("forward")(inputs), "is not defined");

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  ASSERT_THROWS_WITH_MESSAGE(
      bc2.get_method("forward")(inputs), "is not defined");
}
542
TEST(FlatbufferTest, SetState) {
  // A module with custom __getstate__/__setstate__ must produce the same
  // forward result in JIT (after save/load), mobile, and flatbuffer form.
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);

  // Reference: run forward on a save/load copy of the JIT module.
  std::stringstream ms;
  m.save(ms);
  auto loaded_m = load(ms);
  const float refd =
      loaded_m.run_method("forward", minput).toTensor().item<float>();

  // Runs forward three times, each on a fresh copy of `inputs`, and returns
  // the last result as a float.
  auto run_forward = [&inputs](mobile::Module& mod) {
    IValue last;
    for (int iter = 0; iter < 3; ++iter) {
      // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
      auto bcinputs = inputs;
      last = mod.get_method("forward")(bcinputs);
    }
    return last.toTensor().item<float>();
  };

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  AT_ASSERT(run_forward(bc) == refd);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  AT_ASSERT(run_forward(bc2) == refd);
}
589
590 class TorchBindFlatbufferTestStruct : public torch::jit::CustomClassHolder {
591 public:
get(at::Tensor t)592 std::string get(at::Tensor t) {
593 std::stringstream ss;
594 ss << "Hello! Your tensor has ";
595 ss << t.numel();
596 ss << " elements!";
597 return ss.str();
598 }
599 };
600
601 namespace {
602 struct ClassNamespaceValue : public SugaredValue {
ClassNamespaceValuetorch::jit::__anon8587c2c70211::ClassNamespaceValue603 explicit ClassNamespaceValue(c10::QualifiedName name)
604 : basename_(std::move(name)) {}
605
attrtorch::jit::__anon8587c2c70211::ClassNamespaceValue606 std::shared_ptr<SugaredValue> attr(
607 const SourceRange& loc,
608 GraphFunction& m,
609 const std::string& name) override {
610 const auto fullName = c10::QualifiedName(basename_, name);
611
612 // Check to see if it is a custom class.
613 if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
614 return std::make_shared<ClassValue>(custom_class);
615 }
616
617 // If it's not a custom class, assume it's another namespace
618 // NOLINTNEXTLINE(performance-move-const-arg)
619 return std::make_shared<ClassNamespaceValue>(std::move(fullName));
620 }
621
kindtorch::jit::__anon8587c2c70211::ClassNamespaceValue622 std::string kind() const override {
623 return "Class Namespace";
624 }
625
626 private:
627 c10::QualifiedName basename_;
628 };
629
630 struct TestModuleResolver : public Resolver {
resolveValuetorch::jit::__anon8587c2c70211::TestModuleResolver631 std::shared_ptr<SugaredValue> resolveValue(
632 const std::string& name,
633 GraphFunction& m,
634 const SourceRange& loc) override {
635 if (name == "torch") {
636 return std::make_shared<BuiltinModule>("aten");
637 } else if (name == "__torch__") {
638 return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
639 }
640
641 return nullptr;
642 }
643
resolveTypetorch::jit::__anon8587c2c70211::TestModuleResolver644 TypePtr resolveType(const std::string& name, const SourceRange& loc)
645 override {
646 return nullptr;
647 }
648 };
649 } // namespace
650
TEST(FlatbufferTest, BuiltinClass) {
  // A module holding a registered custom-class attribute (resolved through
  // TestModuleResolver) must run forward after a flatbuffer round-trip.
  script::Module m("m");

  auto cls = getCustomClass(
      "__torch__.torch.classes._TorchScriptTesting._FlatbufferTest");
  TORCH_INTERNAL_ASSERT(cls);
  // Registered with an empty capsule; the scripted __setstate__ constructs a
  // real instance of the class.
  c10::intrusive_ptr<torch::CustomClassHolder> obj_holder;
  m.register_attribute("my_obj", cls, IValue::make_capsule(obj_holder));

  m.register_parameter("foo", torch::ones({}), false);
  m.define(
      R"(
    def __getstate__(self):
      return 1
    def __setstate__(self, a):
      self.my_obj = __torch__.torch.classes._TorchScriptTesting._FlatbufferTest()

    def forward(self, x) -> str:
      return self.my_obj.get(x)
  )",
      std::make_shared<TestModuleResolver>());

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());

  const std::string expected = "Hello! Your tensor has 12 elements!";
  auto res =
      bc2.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
  AT_ASSERT(res.toStringRef() == expected);
}
683
TEST(FlatbufferTest, BuiltinFunction) {
  // A torchbind object attribute whose method is called from forward must
  // work in mobile and after a flatbuffer round-trip.
  script::Module m("m");
  auto custom_class_obj = make_custom_class<TorchBindFlatbufferTestStruct>();
  m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj);
  m.define(R"(
    def forward(self, x) -> str:
      return self.my_obj.get(x)
  )");

  const std::string expected = "Hello! Your tensor has 12 elements!";

  // Calls forward on a fresh 3x4 zero tensor and returns a copy of the
  // resulting string.
  auto call_forward = [](mobile::Module& mod) -> std::string {
    auto res =
        mod.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
    return res.toStringRef();
  };

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  AT_ASSERT(call_forward(bc) == expected);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  AT_ASSERT(call_forward(bc2) == expected);
}
709
TEST(FlatbufferTest, Eval) {
  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.define(R"(
    def __init__(self, x):
      self.training = True

    def forward(self, input):
      return torch.dropout(input, 1.0, self.training)
  )");

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
  inputs.push_back(torch::ones({1, 1, 28, 28}));
  // Reference output computed in eval mode.
  m.eval();
  auto outputref = m.forward(inputs).toTensor();

  // save m in training mode to make sure that mobile eval() will correctly
  // change back to eval mode
  m.train();

  // Runs forward three times on `mod` and returns the last output tensor.
  auto run_forward = [&inputs](mobile::Module& mod) {
    IValue last;
    for (int iter = 0; iter < 3; ++iter) {
      last = mod.get_method("forward")(inputs);
    }
    return last.toTensor();
  };
  // Compares rank and the first element (read as int) with the reference.
  auto check_output = [&outputref](const at::Tensor& output) {
    AT_ASSERT(outputref.dim() == output.dim());
    AT_ASSERT(
        outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
  };

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bc.eval();
  check_output(run_forward(bc));

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  bc2.eval();
  check_output(run_forward(bc2));
}
753
TEST(FlatbufferTest, FindWrongMethodName) {
  // Only `add` is defined, so find_method("forward") must come back empty,
  // both before and after the flatbuffer round-trip.
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  ASSERT_FALSE(bc.find_method("forward").has_value());

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  ASSERT_FALSE(bc2.find_method("forward").has_value());
}
770
TEST(FlatbufferTest, FindAndRunMethod) {
  // A method found by name must produce the JIT reference result, both before
  // and after the flatbuffer round-trip.
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_it(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.get_method("add_it")(inputs);
  const float refd = ref.toTensor().item<float>();

  // Looks up add_it (asserting it exists) and invokes it three times, each on
  // a fresh copy of `inputs`. Returns the last result as a float.
  auto find_and_run = [&inputs](mobile::Module& mod) {
    IValue last;
    for (int iter = 0; iter < 3; ++iter) {
      auto bcinputs = inputs;
      auto method = mod.find_method("add_it");
      AT_ASSERT(method != std::nullopt);
      last = (*method)(std::move(bcinputs));
    }
    return last.toTensor().item<float>();
  };

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  AT_ASSERT(find_and_run(bc) == refd);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  AT_ASSERT(find_and_run(bc2) == refd);
}
812
TEST(FlatbufferTest, RunMethodVariadic) {
  // run_method with variadic tensor arguments must match the JIT reference,
  // both on the lowered module and on the one parsed from flatbuffer bytes.
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add_three(self, x, y):
      return self.foo + x + y
  )");

  auto inputx = 5 * torch::ones({});
  auto inputy = 4 * torch::ones({});
  auto ref = m.run_method("add_three", inputx, inputy);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  IValue res = bc.run_method("add_three", inputx, inputy);

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  // Run the parsed module (bc2) so the flatbuffer round-trip is actually
  // exercised.
  res = bc2.run_method("add_three", inputx, inputy);
  resd = res.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}
840
TEST(FlatbufferTest, DuplicateSetState) {
  // The lowered module must expose exactly three methods, both before and
  // after the flatbuffer round-trip.
  Module m("M");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo + self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  Module b("B");
  b.register_module("M0", m);
  b.register_module("M1", m);
  b.define(R"(
    def forward(self, x):
      return self.M0.forward(x) + self.M1.forward(x)
  )");

  CompilationOptions options;
  // NOTE(review): only `m` is lowered here. `b`, which registers `m` twice,
  // is built but never compiled, so the "duplicate" scenario the test name
  // suggests is not exercised. Confirm whether `b` was intended.
  mobile::Module bc = jitModuleToMobile(m, options);
  const auto methods = bc.get_methods();
  const size_t expected_n = 3;
  ASSERT_EQ(methods.size(), expected_n);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  // Query the parsed module (bc2), not the original, so the round-trip is
  // actually checked.
  const auto methods2 = bc2.get_methods();
  ASSERT_EQ(methods2.size(), expected_n);
}
873
TEST(FlatbufferTest, OpNameExportFetchRootOperators) {
  // The exported root-operator list must be identical for the lowered module
  // and for the module parsed back from flatbuffer bytes.
  torch::jit::Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      x1 = torch.zeros(2, 2)
      x2 = torch.empty_like(torch.empty(2, 2))
      x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
      return (x1, x2, x3)
  )");
  m.eval();

  const std::set<std::string> expected_operator_names = {
      "aten::_convolution",
      "aten::empty.memory_format",
      "aten::empty_like",
      "aten::zeros",
  };

  CompilationOptions options;
  mobile::Module ptl_model = jitModuleToMobile(m, options);
  const std::set<std::string> lowered_ops =
      torch::jit::mobile::_export_operator_list(ptl_model);
  EXPECT_EQ(lowered_ops, expected_operator_names)
      << "Expected the root operator lists to be the same";

  auto buff = save_mobile_module_to_bytes(ptl_model);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  const std::set<std::string> parsed_ops =
      torch::jit::mobile::_export_operator_list(bc2);
  EXPECT_EQ(parsed_ops, expected_operator_names)
      << "Expected the root operator lists to be the same";
}
906
TEST(FlatbufferTest, DefaultArgsConv) {
  // Skipped when running under TSAN.
  const char* tsan_flag = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (tsan_flag != nullptr && strcmp(tsan_flag, "1") == 0) {
    return;
  }

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1)
  )");

  inputs.emplace_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  // Runs forward once and asserts the output matches the reference exactly.
  auto check_forward = [&inputs, &outputref](mobile::Module& mod) {
    auto output = mod.get_method("forward")(inputs).toTensor();
    AT_ASSERT(outputref.dim() == output.dim());
    AT_ASSERT(output.equal(outputref));
  };

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  check_forward(bc);

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  check_forward(bc2);
}
945
946 namespace {
testLiteModuleCompareResultTensors(Module & m,const std::vector<torch::jit::IValue> & inputs,const std::string & method_name="forward")947 void testLiteModuleCompareResultTensors(
948 Module& m,
949 const std::vector<torch::jit::IValue>& inputs,
950 const std::string& method_name = "forward") {
951 auto outputref = m.get_method(method_name)(inputs).toTensor();
952
953 CompilationOptions options;
954 mobile::Module bc = jitModuleToMobile(m, options);
955 IValue res;
956 for (int i = 0; i < 3; ++i) {
957 res = bc.get_method(method_name)(inputs);
958 }
959 auto output = res.toTensor();
960 AT_ASSERT(outputref.dim() == output.dim());
961 AT_ASSERT(output.equal(outputref));
962
963 auto buff = save_mobile_module_to_bytes(bc);
964 mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
965 for (int i = 0; i < 3; ++i) {
966 res = bc2.get_method(method_name)(inputs);
967 }
968 output = res.toTensor();
969 AT_ASSERT(outputref.dim() == output.dim());
970 AT_ASSERT(output.equal(outputref));
971 }
972
testDefaultArgsPinv(int num_args)973 static void testDefaultArgsPinv(int num_args) {
974 Module m("m");
975 if (num_args == 1) {
976 m.define(R"(
977 def forward(self, input):
978 return torch.linalg_pinv(input)
979 )");
980 } else if (num_args == 2) {
981 m.define(R"(
982 def forward(self, input):
983 return torch.linalg_pinv(input, 1e-5)
984 )");
985 } else if (num_args == 3) {
986 m.define(R"(
987 def forward(self, input):
988 return torch.linalg_pinv(input, 1e-5, True)
989 )");
990 }
991
992 std::vector<torch::jit::IValue> inputs;
993 const int N = 28;
994 auto input = torch::range(1, N * N, 1);
995 input[0] = 1; // a more stable matrix
996 input = input.view({N, N});
997 inputs.emplace_back(input);
998 testLiteModuleCompareResultTensors(m, inputs);
999 }
1000 } // namespace
1001
1002 #if !defined FB_XPLAT_BUILD
TEST(FlatbufferTest, DefaultArgsPinv) {
  // Exercise linalg_pinv with 1, 2 and 3 explicitly specified arguments.
  // Arguments that are not specified take their default values.
  for (const int num_args : {1, 2, 3}) {
    testDefaultArgsPinv(num_args);
  }

  // Bytecode with one specified argument:
  // (6,
  //     ('__torch__.m.forward',
  //         (('instructions',
  //             (('STOREN', 1, 2),
  //                 ('DROPR', 1, 0),
  //                 ('MOVE', 2, 0),
  //                 ('OP', 0, 0),
  //                 ('RET', 0, 0))),
  //             ('operators', (('aten::linalg_pinv', '', 1),)),
  //             ('constants', (False, 1e-15)),  # default constants are not
  //                                             # used
  //             ('types', ()),
  //             ('register_size', 2)),
  //         (('arguments',
  //             ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
  //             None)),
  //                 (('name', 'input'), ('type', 'Tensor'), ('default_value',
  //                 None)))),
  //             ('returns',
  //                 ((('name', ''), ('type', 'Tensor'), ('default_value',
  //                 None)),)))))

  // Bytecode with two specified arguments:
  // (6,
  //     ('__torch__.m.forward',
  //         (('instructions',
  //             (('STOREN', 1, 2),
  //                 ('DROPR', 1, 0),
  //                 ('MOVE', 2, 0),
  //                 ('LOADC', 1, 0),  # added LOADC for specified argument
  //                 ('OP', 0, 0),
  //                 ('RET', 0, 0))),
  //             ('operators', (('aten::linalg_pinv', '', 2),)),
  //             ('constants', (False, 1e-05)),  # updated constant table
  //             ('types', ()),
  //             ('register_size', 2)),
  //         (('arguments',
  //             ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
  //             None)),
  //                 (('name', 'input'), ('type', 'Tensor'), ('default_value',
  //                 None)))),
  //             ('returns',
  //                 ((('name', ''), ('type', 'Tensor'), ('default_value',
  //                 None)),)))))

  // Bytecode with three specified arguments:
  // (6,
  //     ('__torch__.m.forward',
  //         (('instructions',
  //             (('STOREN', 1, 2),
  //                 ('DROPR', 1, 0),
  //                 ('MOVE', 2, 0),
  //                 ('LOADC', 1, 0),
  //                 ('LOADC', 0, 0),
  //                 ('OP', 0, 0),
  //                 ('RET', 0, 0))),
  //             ('operators', (('aten::linalg_pinv', '', 3),)),
  //             ('constants', (True, 1e-05)),
  //             ('types', ()),
  //             ('register_size', 2)),
  //         (('arguments',
  //             ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
  //             None)),
  //                 (('name', 'input'), ('type', 'Tensor'), ('default_value',
  //                 None)))),
  //             ('returns',
  //                 ((('name', ''), ('type', 'Tensor'), ('default_value',
  //                 None)),)))))
}
1080
TEST(FlatbufferTest, DefaultArgsTensorinvSpecifyDefault) {
  // The second argument is passed explicitly, but its value equals the
  // schema default. It is therefore counted as "not specified": the value can
  // be recovered from the schema, so only one argument is recorded.
  Module m("m");
  m.define(R"(
    def forward(self, input):
      return torch.linalg_tensorinv(input, 2)
  )");

  torch::jit::MobileCode mobile_code(
      m.get_method("forward").graph(), "forward");
  auto specified_arg_counts = mobile_code.op_to_num_specified_args();
  ASSERT_EQ(specified_arg_counts.size(), 1);
  ASSERT_EQ(specified_arg_counts["aten::linalg_tensorinv"], 1);

  const int N = 4;
  std::vector<torch::jit::IValue> inputs;
  inputs.emplace_back(torch::rand({N, N, N, N}));
  testLiteModuleCompareResultTensors(m, inputs);
}
1100
testDefaultArgsPinvWithOutArg(int num_args)1101 static void testDefaultArgsPinvWithOutArg(int num_args) {
1102 Module m("m");
1103 if (num_args == 1) {
1104 m.define(R"(
1105 def forward(self, input):
1106 return torch.linalg_pinv(input, out=input)
1107 )");
1108 } else if (num_args == 2) {
1109 m.define(R"(
1110 def forward(self, input):
1111 return torch.linalg_pinv(input, 1e-5, out=input)
1112 )");
1113 } else if (num_args == 3) {
1114 m.define(R"(
1115 def forward(self, input):
1116 return torch.linalg_pinv(input, 1e-5, True, out=input)
1117 )");
1118 }
1119
1120 const int N = 28;
1121 auto input = torch::range(1, N * N, 1);
1122 input[0] = 10000; // a more stable matrix
1123 input = input.view({N, N});
1124 auto ref = m.run_method("forward", input);
1125 TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
1126 TORCH_CHECK(input.equal(ref.toTensor()));
1127 }
1128
TEST(FlatbufferTest, DefaultArgsPinvWithOutArg) {
  // Cover 1, 2 and 3 explicitly specified arguments, each combined with an
  // out= argument. Unspecified arguments fall back to their defaults.
  for (const int num_args : {1, 2, 3}) {
    testDefaultArgsPinvWithOutArg(num_args);
  }
}
1136
TEST(FlatbufferTest, DefaultArgsWithOutArg) {
  // forward() accumulates h into x in place via `out=x` and returns nothing.
  // All checks are therefore made on the mutated `x` tensors.
  Module m("m");
  m.define(R"(
    def forward(self, x, h):
      torch.add(x, h, out=x)
  )");

  auto input_x = 2 * torch::ones({});
  auto input_h = torch::ones({});
  // Full JIT: x = 2 + 1 = 3.
  m.run_method("forward", input_x, input_h);
  AT_ASSERT(input_x.equal(3 * torch::ones({})));

  // Mobile module converted in memory: x = 3 + 1 = 4.
  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  bc.run_method("forward", input_x, input_h);
  AT_ASSERT(input_x.equal(4 * torch::ones({})));

  // Mobile module round-tripped through a flatbuffer, applied after a fresh
  // JIT run: x = 2 + 1 + 1 = 4.
  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  auto input_x2 = 2 * torch::ones({});
  auto input_h2 = torch::ones({});
  m.run_method("forward", input_x2, input_h2);
  AT_ASSERT(input_x2.equal(3 * torch::ones({})));
  bc2.run_method("forward", input_x2, input_h2);
  AT_ASSERT(input_x2.equal(4 * torch::ones({})));
}
1162
1163 #endif // !defined(FB_XPLAT_BUILD)
1164
1165 namespace {
1166 static auto reg =
1167 torch::class_<TorchBindFlatbufferTestStruct>(
1168 "_TorchScriptTesting",
1169 "_FlatbufferTest")
1170 .def(torch::init<>())
1171 .def("get", &TorchBindFlatbufferTestStruct::get)
1172 .def_pickle(
1173 // __getattr__
1174 [](const c10::intrusive_ptr<TorchBindFlatbufferTestStruct>& self)
__anon8587c2c70502(const c10::intrusive_ptr<TorchBindFlatbufferTestStruct>& self) 1175 -> int64_t { return 0; },
1176 // __setattr__
__anon8587c2c70602(int64_t state) 1177 [](int64_t state) {
1178 return c10::make_intrusive<TorchBindFlatbufferTestStruct>();
1179 });
1180
1181 } // namespace
1182
TEST(FlatbufferTest, OperatorCacheDifferentiatesDefaultArgs) {
  // Three methods call the same operator with different argument counts:
  //
  //   1. forward()  returns a tensor with dtype=torch.int64 (4)
  //   2. forward2() returns a tensor with dtype=torch.float32 (6)
  //   3. forward3() returns a tensor with dtype=torch.float32, where the dtype
  //      is inferred from the input tensor's dtype
  //
  // With correct operator caching, the full-JIT and lite results agree. If the
  // cache entry for an operator called with a different number of arguments
  // were wrongly reused, the results would differ.
  Module m("m");
  m.define(R"(
    def forward(self):
      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4)
      return ret1.fill_(25)
  )");
  m.define(R"(
    def forward2(self):
      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6)
      return ret1.fill_(32.0)
  )");
  m.define(R"(
    def forward3(self):
      ret1 = torch.new_empty(torch.zeros(10), [10])
      return ret1.fill_(12.0)
  )");

  const std::vector<torch::jit::IValue> no_inputs;
  for (const char* method_name : {"forward", "forward2", "forward3"}) {
    testLiteModuleCompareResultTensors(m, no_inputs, method_name);
  }
}
1218
TEST(FlatbufferTest, OperatorSize1) {
  // Each operator must have a matching recorded input size, both for the
  // in-memory mobile module and for the module parsed from a flatbuffer.
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  const auto& func = bc.get_method("forward").function();
  ASSERT_EQ(
      func.get_code().operator_input_sizes_.size(),
      func.get_code().operators_.size());

  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
  // Inspect the module parsed back from the flatbuffer (bc2), not bc.
  const auto& func2 = bc2.get_method("forward").function();
  ASSERT_EQ(
      func2.get_code().operator_input_sizes_.size(),
      func2.get_code().operators_.size());
}
1240
TEST(FlatbufferTest, BoolAndDoubleList) {
  // bool and double list attributes must survive a flatbuffer round trip
  // with their element types intact.
  Module m("m");

  c10::List<bool> bools;
  bools.push_back(false);
  IValue bool_list_value = bools;
  IValue double_list_value = std::vector<double>{2.0};
  m.register_attribute("bool_list", bool_list_value.type(), bool_list_value);
  m.register_attribute(
      "double_list", double_list_value.type(), double_list_value);

  CompilationOptions options;
  mobile::Module bc = jitModuleToMobile(m, options);
  auto buff = save_mobile_module_to_bytes(bc);
  mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());

  // If an attribute is read back with the wrong type, the
  // toBoolList()/toDoubleList() conversion raises an exception.
  const bool first_bool = bc2.attr("bool_list", {}).toBoolList().get(0);
  const double first_double =
      bc2.attr("double_list", {}).toDoubleList().get(0);

  ASSERT_EQ(first_bool, false);
  ASSERT_EQ(first_double, 2.0);
}
1262
TEST(FlatbufferTest, OperatorTest2) { // NOLINT (use =delete in gtest)
  // For each program, every operator must have a matching recorded input
  // size. This is checked both in the in-memory mobile module and in the
  // module parsed from a flatbuffer.
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
      def test_func(self, x, b : int = 4):
        return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
      def add_with_default_arg(self, x, b : int = 4):
        return self.foo + x + b
      def test_func(self, x):
        return self.add_with_default_arg(x) # invoke method w/ default arg
      )",
      // simple method call
      R"(
      def test_func(self, x):
        b = 4
        return self.foo + x + b
      )",
  };
  for (const auto& test_program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(test_program);

    CompilationOptions options;
    mobile::Module bc = jitModuleToMobile(m, options);
    const auto& func = bc.get_method("test_func").function();
    ASSERT_EQ(
        func.get_code().operator_input_sizes_.size(),
        func.get_code().operators_.size());

    auto buff = save_mobile_module_to_bytes(bc);
    mobile::Module bc2 = parse_mobile_module(buff->data(), buff->size());
    // Inspect the module parsed back from the flatbuffer (bc2), not bc.
    const auto& func2 = bc2.get_method("test_func").function();
    ASSERT_EQ(
        func2.get_code().operator_input_sizes_.size(),
        func2.get_code().operators_.size());
  }
}
1304
jitModuleFromBuffer(void * data,size_t size)1305 Module jitModuleFromBuffer(void* data, size_t size) {
1306 // Make a copy of the data so we can use the existing API, which takes
1307 // ownership. The `data` param might point into the middle of a buffer, so we
1308 // can't safely take ownership of it directly.
1309 // @nolint CLANGTIDY cppcoreguidelines-no-malloc
1310 std::shared_ptr<char> copy(static_cast<char*>(malloc(size)), free);
1311 memcpy(copy.get(), data, size);
1312
1313 ExtraFilesMap extra_files;
1314 return parse_and_initialize_jit_module(std::move(copy), size, extra_files);
1315 }
1316
TEST(TestSourceFlatbuffer, UpsampleNearest2d) {
  // Save the module as a flatbuffer, reload it both as a mobile module and as
  // a full JIT module, and check that both match the original output.
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale:float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  const auto expected = m.forward(inputs).toTensor();

  std::stringstream ss;
  m._save_for_mobile(
      ss, {}, /*save_mobile_debug_info=*/false, /*use_flatbuffer=*/true);
  auto mobile_module = _load_for_mobile(ss);
  auto jit_module = load(ss);

  const auto jit_result = jit_module.forward(inputs).toTensor();
  const auto mobile_result = mobile_module.forward(inputs).toTensor();

  ASSERT_TRUE(jit_result.equal(expected));
  ASSERT_TRUE(mobile_result.equal(expected));
}
1343
TEST(TestSourceFlatbuffer, CheckAttrAccess) {
  // A bool attribute must be readable after loading the flatbuffer both as a
  // full JIT module and as a mobile module.
  Module m("m");
  m.register_attribute("mobile_optimized", BoolType::get(), true);
  auto data = save_jit_module_to_bytes(m);

  Module jit_module = jitModuleFromBuffer(data->data(), data->size());
  AT_ASSERT(jit_module.attr("mobile_optimized", false).toBool());

  mobile::Module mobile_module =
      parse_mobile_module(data->data(), data->size());
  AT_ASSERT(mobile_module.attr("mobile_optimized", false).toBool());
}
1355
TEST(TestSourceFlatbuffer,
     MethodInvocation) { // NOLINT (use =delete in gtest)
  const std::vector<std::string> test_programs{
      // test invoking a method with default parameter
      R"(
      def test_func(self, x, b : int = 4):
        return self.foo + x + b
      )",
      // inner method call with default parameter (gets inlined)
      R"(
      def add_with_default_arg(self, x, b : int = 4):
        return self.foo + x + b
      def test_func(self, x):
        return self.add_with_default_arg(x) # invoke method w/ default arg
      )",
      // simple method call
      R"(
      def test_func(self, x):
        b = 4
        return self.foo + x + b
      )",
  };
  for (const auto& program : test_programs) {
    Module m("m");
    m.register_parameter("foo", torch::ones({}), false);
    m.define(program);

    const int fortyTwo = 42; // (keep linter happy)
    auto minput = fortyTwo * torch::ones({});
    const float expected =
        m.run_method("test_func", minput).toTensor().item<float>();

    auto data = save_jit_module_to_bytes(m);

    // Full JIT module reloaded from the flatbuffer; invoked three times and
    // only the last result is checked.
    Module jit_module = jitModuleFromBuffer(data->data(), data->size());
    const auto& jit_method = jit_module.get_method("test_func");
    IValue result;
    for (int iteration = 0; iteration < 3; ++iteration) {
      result = jit_method({minput});
    }
    AT_ASSERT(result.toTensor().item<float>() == expected);

    // Mobile module parsed from the same bytes.
    mobile::Module mobile_module =
        parse_mobile_module(data->data(), data->size());
    const auto& mobile_method = mobile_module.get_method("test_func");
    for (int iteration = 0; iteration < 3; ++iteration) {
      result = mobile_method({minput});
    }
    AT_ASSERT(result.toTensor().item<float>() == expected);
  }
}
1408
1409 #if !defined FB_XPLAT_BUILD
// The following tests are excluded from FB_XPLAT_BUILD builds.
TEST(FlatbufferUpgraderTest, DivTensorV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append("upgrader_models/test_versioned_div_tensor_v2.ptl.ff");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('LOAD', 2, 0),
       ('LOAD', 3, 0),
       ('OP', 0, 0),
       ('LOAD', 2, 0),
       ('LOAD', 3, 0),
       ('OP', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 2, 0),
       ('TUPLE_CONSTRUCT', 3, 0),
       ('RET', 0, 0))),
     ('operators',
      (('aten::div', 'Tensor'),
       ('aten::div', 'Tensor'),
       ('aten::div', 'Tensor'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);
  // Bind by reference; there is no need to copy the instruction vector.
  const auto& instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (const auto& instruction : instruction_list) {
    if (instruction.op == OpCode::CALL) {
      ++number_of_call_instruction;
    }
  }
  // 3 operators will use upgrader
  ASSERT_EQ(number_of_call_instruction, 3u);

  std::vector<IValue> inputs = {
      IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))};
  auto actual_output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output_list = actual_output.toTuple()->elements();
  // All three tuple elements are div(a, b) on the same inputs; check each.
  ASSERT_EQ(actual_output_list.size(), 3u);
  for (const auto& element : actual_output_list) {
    ASSERT_TRUE(element.toTensor().equal(expect_output));
  }
}
1457
TEST(FlatbufferUpgraderTest, DivTensorOutV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_tensor_out_v2.ptl.ff");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 4),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('MOVE', 4, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div', 'out'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 4))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  // Bind by reference; there is no need to copy the instruction vector.
  const auto& instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (const auto& instruction : instruction_list) {
    if (instruction.op == OpCode::CALL) {
      ++number_of_call_instruction;
    }
  }
  // One operator will use upgrader
  ASSERT_EQ(number_of_call_instruction, 1u);

  std::vector<IValue> inputs{
      IValue(6 * torch::ones({1})),
      IValue(3 * torch::ones({1})),
      IValue(torch::empty({1}))};
  m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = inputs[2].toTensor();
  // The out argument will be overwritten with the output
  ASSERT_TRUE(actual_output.equal(expect_output));
}
1499
TEST(FlatbufferUpgraderTest, DivTensorInplaceV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_tensor_inplace_v2.ptl.ff");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div_', 'Tensor'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  // Bind by reference; there is no need to copy the instruction vector.
  const auto& instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (const auto& instruction : instruction_list) {
    if (instruction.op == OpCode::CALL) {
      ++number_of_call_instruction;
    }
  }
  // One operator will use upgrader
  ASSERT_EQ(number_of_call_instruction, 1u);

  std::vector<IValue> inputs{
      IValue(6 * torch::ones({1})), IValue(3 * torch::ones({1}))};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  // div_ divides its first argument in place, so the first input tensor holds
  // the result.
  auto actual_output = inputs[0].toTensor();
  ASSERT_TRUE(actual_output.equal(expect_output));
  // forward() returns the result of div_ as well.
  ASSERT_TRUE(output.toTensor().equal(expect_output));
}
1538
TEST(FlatbufferUpgraderTest, DivScalarFloatV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_float_v2.ptl.ff");
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */

  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  // Bind by reference; there is no need to copy the instruction vector.
  const auto& instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (const auto& instruction : instruction_list) {
    if (instruction.op == OpCode::CALL) {
      ++number_of_call_instruction;
    }
  }
  // One operator will use upgrader
  ASSERT_EQ(number_of_call_instruction, 1u);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // The quotient is returned from forward(); there is no out argument.
  ASSERT_TRUE(actual_output.equal(expect_output));
}
1578
TEST(FlatbufferUpgraderTest, DivScalarReciprocalFloatV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_reciprocal_float_v2.ptl.ff");
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('OP', 0, 0),
       ('MOVE', 3, 0),
       ('OP', 1, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  // Bind by reference; there is no need to copy the instruction vector.
  const auto& instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (const auto& instruction : instruction_list) {
    if (instruction.op == OpCode::CALL) {
      ++number_of_call_instruction;
    }
  }
  // No operator will use upgrader
  ASSERT_EQ(number_of_call_instruction, 0u);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 0.5 * torch::ones({1});
  auto actual_output = output.toTensor();
  // reciprocal(6) * 3 = 0.5, returned from forward(); there is no out argument.
  ASSERT_TRUE(actual_output.equal(expect_output));
}
1617
TEST(FlatbufferUpgraderTest, DivScalarReciprocalIntV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_reciprocal_int_v2.ptl.ff");
  /*
  (('__torch__.MyModuleInt.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('OP', 0, 0),
       ('MOVE', 3, 0),
       ('OP', 1, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::reciprocal', ''), ('aten::mul', 'Scalar'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  // Bind by reference; there is no need to copy the instruction vector.
  const auto& instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (const auto& instruction : instruction_list) {
    if (instruction.op == OpCode::CALL) {
      ++number_of_call_instruction;
    }
  }
  // No operator will use upgrader
  ASSERT_EQ(number_of_call_instruction, 0u);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 0.5 * torch::ones({1});
  auto actual_output = output.toTensor();

  // reciprocal(6) * 3 = 0.5, returned from forward(); there is no out argument.
  ASSERT_TRUE(actual_output.equal(expect_output));
}
1657
TEST(FlatbufferUpgraderTest, DivScalarScalarV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_scalar_v2.ptl.ff");
  /*
  (('__torch__.MyModule.forward',
    (('instructions',
      (('STOREN', 1, 5),
       ('DROPR', 1, 0),
       ('LOAD', 2, 0),
       ('LOAD', 3, 0),
       ('OP', 0, 0),
       ('MOVE', 2, 0),
       ('LOAD', 4, 0),
       ('OP', 1, 0),
       ('LOAD', 3, 0),
       ('MOVE', 4, 0),
       ('OP', 2, 0),
       ('MOVE', 3, 0),
       ('MOVE', 5, 0),
       ('OP', 3, 0),
       ('TUPLE_CONSTRUCT', 4, 0),
       ('RET', 0, 0))),
     ('operators',
      (('aten::div', ''),
       ('aten::div', 'float'),
       ('aten::div', ''),
       ('aten::div', 'int'))),
     ('constants', ()),
     ('types', ()),
     ('register_size', 5))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);
  // Bind by reference; there is no need to copy the instruction vector.
  const auto& instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (const auto& instruction : instruction_list) {
    if (instruction.op == OpCode::CALL) {
      ++number_of_call_instruction;
    }
  }
  // No operator will use upgrader
  ASSERT_EQ(number_of_call_instruction, 0u);

  std::vector<IValue> inputs{IValue(20.0), IValue(10), IValue(2.0), IValue(5)};
  auto output = m_module.forward(inputs);
  auto output_list = output.toTupleRef().elements();
  const std::vector<IValue> expect_output{
      IValue(2.0), IValue(10.0), IValue(5.0), IValue(2.0)};
  // Guard the indexed comparison below against a shorter output tuple.
  ASSERT_EQ(output_list.size(), expect_output.size());
  for (size_t i = 0; i < expect_output.size(); i++) {
    ASSERT_EQ(output_list[i], expect_output[i]);
  }
}
1711
TEST(FlatbufferUpgraderTest, DivScalarIntV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_int_v2.ptl.ff");
  /*
  (('__torch__.MyModuleInt.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */
  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  // Bind by reference; there is no need to copy the instruction vector.
  const auto& instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (const auto& instruction : instruction_list) {
    if (instruction.op == OpCode::CALL) {
      ++number_of_call_instruction;
    }
  }
  // One operator will use upgrader
  ASSERT_EQ(number_of_call_instruction, 1u);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // The quotient is returned from forward(); there is no out argument.
  ASSERT_TRUE(actual_output.equal(expect_output));
}
1750
TEST(FlatbufferUpgraderTest, DivScalarInplaceFloatV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_inplace_float_v2.ptl.ff");
  /*
  (('__torch__.MyModuleFloat.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div_', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */

  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  // Bind by reference; there is no need to copy the instruction vector.
  const auto& instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (const auto& instruction : instruction_list) {
    if (instruction.op == OpCode::CALL) {
      ++number_of_call_instruction;
    }
  }
  // One operator will use upgrader
  ASSERT_EQ(number_of_call_instruction, 1u);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3.0)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // forward() returns the result of the in-place div_; there is no out
  // argument.
  ASSERT_TRUE(actual_output.equal(expect_output));
}
1790
TEST(FlatbufferUpgraderTest, DivScalarInplaceIntV2) {
  std::string filePath(__FILE__);
  auto test_model_file = filePath.substr(0, filePath.find_last_of("/\\") + 1);
  test_model_file.append(
      "upgrader_models/test_versioned_div_scalar_inplace_int_v2.ptl.ff");
  /*
  (('__torch__.MyModuleInt.forward',
    (('instructions',
      (('STOREN', 1, 3),
       ('DROPR', 1, 0),
       ('MOVE', 2, 0),
       ('MOVE', 3, 0),
       ('OP', 0, 0),
       ('RET', 0, 0))),
     ('operators', (('aten::div_', 'Scalar'),)),
     ('constants', ()),
     ('types', ()),
     ('register_size', 3))),)
  */

  mobile::Module m_module = load_mobile_module_from_file(test_model_file);

  // Bind by reference; there is no need to copy the instruction vector.
  const auto& instruction_list =
      m_module.get_method("forward").function().get_code().instructions_;
  uint64_t number_of_call_instruction = 0;
  for (const auto& instruction : instruction_list) {
    if (instruction.op == OpCode::CALL) {
      ++number_of_call_instruction;
    }
  }
  // One operator will use upgrader
  ASSERT_EQ(number_of_call_instruction, 1u);

  std::vector<IValue> inputs{IValue(6 * torch::ones({1})), IValue(3)};
  auto output = m_module.forward(inputs);
  auto expect_output = 2.0 * torch::ones({1});
  auto actual_output = output.toTensor();

  // forward() returns the result of the in-place div_; there is no out
  // argument.
  ASSERT_TRUE(actual_output.equal(expect_output));
}
1830
1831 #endif // !defined(FB_XPLAT_BUILD)
1832
1833 //
1834 // Tests that need access to internal flatbuffers types/functions.
1835 // Do not add any other tests after this section.
1836 //
1837
1838 } // namespace jit
1839 } // namespace torch
1840 namespace torch {
1841 namespace jit {
1842
1843 /**
1844 * An Allocator that can only deallocate (using delete []), counting
1845 * the number of times that it has been asked to deallocate.
1846 */
1847 class TestAllocator : public flatbuffers::Allocator {
1848 public:
1849 /**
1850 * *deallocate_call_count will be incremented whenever deallocate() is called.
1851 */
TestAllocator(int * deallocate_call_count)1852 explicit TestAllocator(int* deallocate_call_count)
1853 : deallocate_call_count_(deallocate_call_count) {}
1854
deallocate(uint8_t * p,size_t)1855 void deallocate(uint8_t* p, size_t /*size*/) override {
1856 *deallocate_call_count_ += 1;
1857 delete[] p;
1858 }
1859
allocate(size_t)1860 uint8_t* allocate(size_t) override {
1861 TORCH_CHECK(false, "allocate() should not be called");
1862 }
reallocate_downward(uint8_t *,size_t,size_t,size_t,size_t)1863 uint8_t* reallocate_downward(uint8_t*, size_t, size_t, size_t, size_t)
1864 override {
1865 TORCH_CHECK(false, "reallocate_downward() should not be called");
1866 }
1867
1868 private:
1869 int* deallocate_call_count_;
1870 };
1871
1872 /// Provides access to DetachedBuffer::destroy().
1873 struct DetachedBufferTestingFriend {
1874 /// Returns a UniqueDetachedBuffer that wraps the provided DetachedBuffer.
1875 /// A copy of similar code in flatbuffer_serializer.cpp.
make_unique_detached_buffertorch::jit::DetachedBufferTestingFriend1876 static DetachedBuffer::UniqueDetachedBuffer make_unique_detached_buffer(
1877 DetachedBuffer* buf) {
1878 return DetachedBuffer::UniqueDetachedBuffer(buf, DetachedBuffer::destroy);
1879 }
1880 };
1881
TEST(FlatbufferTest, DetachedBufferSmoke) {
  // A counting allocator lets us observe exactly when the bytes owned by a
  // flatbuffers::DetachedBuffer are released.
  int num_deallocations = 0;
  TestAllocator counting_allocator(&num_deallocations);

  // Backing bytes; TestAllocator releases them with `delete []`.
  constexpr size_t kNumBytes = 4;
  uint8_t* const bytes = new uint8_t[kNumBytes];

  // Stage 1: a stack flatbuffers::DetachedBuffer owns the bytes (but not the
  // allocator).
  flatbuffers::DetachedBuffer stack_fb_buffer(
      &counting_allocator,
      /*own_allocator=*/false,
      bytes,
      kNumBytes,
      bytes,
      kNumBytes);
  EXPECT_EQ(stack_fb_buffer.data(), bytes);
  EXPECT_EQ(stack_fb_buffer.size(), kNumBytes);

  // Stage 2: move ownership onto the heap, mirroring what
  // save_mobile_module_to_bytes does.
  auto* heap_fb_buffer =
      new flatbuffers::DetachedBuffer(std::move(stack_fb_buffer));
  // Moving must not have released anything.
  EXPECT_EQ(num_deallocations, 0);
  // The heap object now refers to the bytes...
  EXPECT_EQ(heap_fb_buffer->data(), bytes);
  EXPECT_EQ(heap_fb_buffer->size(), kNumBytes);
  // ...and the moved-from stack object refers to nothing.
  // @lint-ignore CLANGTIDY bugprone-use-after-move
  EXPECT_EQ(stack_fb_buffer.data(), nullptr);
  // @lint-ignore CLANGTIDY bugprone-use-after-move
  EXPECT_EQ(stack_fb_buffer.size(), 0);

  // Stage 3: the top-level torch::jit::DetachedBuffer, whose internal owner
  // is the heap flatbuffers buffer.
  auto* jit_buffer = new DetachedBuffer(
      heap_fb_buffer->data(), heap_fb_buffer->size(), heap_fb_buffer);
  EXPECT_EQ(jit_buffer->data(), bytes);
  EXPECT_EQ(jit_buffer->size(), kNumBytes);

  // Stage 4: a UniqueDetachedBuffer owns jit_buffer and everything below it.
  {
    DetachedBuffer::UniqueDetachedBuffer owner =
        DetachedBufferTestingFriend::make_unique_detached_buffer(jit_buffer);
    EXPECT_EQ(owner->data(), bytes);
    EXPECT_EQ(owner->size(), kNumBytes);

    // While the owner is alive, the bytes must still be alive.
    EXPECT_EQ(num_deallocations, 0);
  }

  // The owner has gone out of scope, so the bytes must have been freed once.
  EXPECT_EQ(num_deallocations, 1);
}
1932
TEST(FlatbufferTest, DetachedBufferNullOwner) {
  // The bytes belong to this vector, not to any DetachedBuffer.
  std::vector<uint8_t> storage(4);

  // A torch::jit::DetachedBuffer constructed without an internal owner.
  auto* jit_buffer = new DetachedBuffer(storage.data(), storage.size());

  {
    // The UniqueDetachedBuffer takes ownership of jit_buffer.
    DetachedBuffer::UniqueDetachedBuffer owner =
        DetachedBufferTestingFriend::make_unique_detached_buffer(jit_buffer);
    EXPECT_EQ(owner->data(), storage.data());
    EXPECT_EQ(owner->size(), storage.size());
  }

  // Leaving the scope above destroyed the DetachedBuffer. The pass condition
  // is finishing the test, including destroying `storage`, without a crash
  // or ASAN report.
}
1950
1951 //
1952 // Do not add tests here unless they require flatbuffers types. See comment at
1953 // the beginning of this section.
1954 //
1955
1956 } // namespace jit
1957 } // namespace torch
1958