1 #include <test/cpp/jit/test_utils.h>
2
3 #include <gtest/gtest.h>
4
5 #include <c10/core/TensorOptions.h>
6 #include <torch/csrc/autograd/generated/variable_factories.h>
7 #include <torch/csrc/jit/api/module.h>
8 #include <torch/csrc/jit/mobile/import.h>
9 #include <torch/csrc/jit/mobile/import_data.h>
10 #include <torch/csrc/jit/mobile/module.h>
11 #include <torch/csrc/jit/mobile/train/export_data.h>
12 #include <torch/csrc/jit/mobile/train/optim/sgd.h>
13 #include <torch/csrc/jit/mobile/train/random.h>
14 #include <torch/csrc/jit/mobile/train/sequential.h>
15 #include <torch/csrc/jit/serialization/import.h>
16 #include <torch/data/dataloader.h>
17 #include <torch/torch.h>
18
19 // Tests go in torch::jit
20 namespace torch {
21 namespace jit {
22
// Trains the same one-parameter linear model with the full JIT interpreter
// and with the lite (mobile) interpreter, using torch::optim::SGD for both,
// and checks that the learned parameter values match exactly.
TEST(LiteTrainerTest, Params) {
  Module m("m");
  m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
  m.define(R"(
    def forward(self, x):
      b = 1.0
      return self.foo * x + b
  )");
  const double learning_rate = 0.1;
  const double momentum = 0.1;
  const int n_epoc = 10;
  // init: y = x + 1;
  // target: y = 2 x + 1
  std::vector<std::pair<Tensor, Tensor>> trainData{
      {1 * torch::ones({1}), 3 * torch::ones({1})},
  };
  // Reference: Full jit
  std::stringstream ms;
  m.save(ms);
  auto mm = load(ms);
  // mm.train();
  std::vector<::at::Tensor> parameters;
  // Bind by const reference to avoid a redundant Tensor handle copy per
  // iteration (clang-tidy: performance-for-range-copy).
  for (const auto& parameter : mm.parameters()) {
    parameters.emplace_back(parameter);
  }
  ::torch::optim::SGD optimizer(
      parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  for (int epoc = 0; epoc < n_epoc; ++epoc) {
    // Structured bindings by const reference: the originals copied both
    // tensors out of the pair on every iteration.
    for (const auto& [source, targets] : trainData) {
      optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = mm.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      optimizer.step();
    }
  }
  // Test: lite interpreter trained with the same schedule.
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<::at::Tensor> bc_parameters = bc.parameters();
  ::torch::optim::SGD bc_optimizer(
      bc_parameters,
      ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  for (int epoc = 0; epoc < n_epoc; ++epoc) {
    for (const auto& [source, targets] : trainData) {
      bc_optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = bc.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      bc_optimizer.step();
    }
  }
  // Both runs started from identical weights and saw identical data, so the
  // trained parameters must agree bit-for-bit in float.
  AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>());
}
80
// TODO Re-enable these tests after parameters are correctly loaded on mobile
82 /*
83 TEST(MobileTest, NamedParameters) {
84 Module m("m");
85 m.register_parameter("foo", torch::ones({}), false);
86 m.define(R"(
87 def add_it(self, x):
88 b = 4
89 return self.foo + x + b
90 )");
91 Module child("m2");
92 child.register_parameter("foo", 4 * torch::ones({}), false);
93 child.register_parameter("bar", 4 * torch::ones({}), false);
94 m.register_module("child1", child);
95 m.register_module("child2", child.clone());
96 std::stringstream ss;
97 m._save_for_mobile(ss);
98 mobile::Module bc = _load_for_mobile(ss);
99
100 auto full_params = m.named_parameters();
101 auto mobile_params = bc.named_parameters();
102 AT_ASSERT(full_params.size() == mobile_params.size());
103 for (const auto& e : full_params) {
104 AT_ASSERT(e.value.item().toInt() ==
105 mobile_params[e.name].item().toInt());
106 }
107 }
108
109 TEST(MobileTest, SaveLoadParameters) {
110 Module m("m");
111 m.register_parameter("foo", torch::ones({}), false);
112 m.define(R"(
113 def add_it(self, x):
114 b = 4
115 return self.foo + x + b
116 )");
117 Module child("m2");
118 child.register_parameter("foo", 4 * torch::ones({}), false);
119 child.register_parameter("bar", 3 * torch::ones({}), false);
120 m.register_module("child1", child);
121 m.register_module("child2", child.clone());
122 auto full_params = m.named_parameters();
123 std::stringstream ss;
124 std::stringstream ss_data;
125 m._save_for_mobile(ss);
126
127 // load mobile module, save mobile named parameters
128 mobile::Module bc = _load_for_mobile(ss);
129 _save_parameters(bc.named_parameters(), ss_data);
130
131 // load back the named parameters, compare to full-jit Module's
132 auto mobile_params = _load_parameters(ss_data);
133 AT_ASSERT(full_params.size() == mobile_params.size());
134 for (const auto& e : full_params) {
135 AT_ASSERT(e.value.item<int>() == mobile_params[e.name].item<int>());
136 }
137 }
138 */
139
// A module tree with no registered parameters should round-trip through
// _save_parameters/_load_parameters as an empty parameter map.
TEST(MobileTest, SaveLoadParametersEmpty) {
  Module m("m");
  m.define(R"(
    def add_it(self, x):
      b = 4
      return x + b
  )");
  Module child("m2");
  m.register_module("child1", child);
  m.register_module("child2", child.clone());

  std::stringstream module_stream;
  std::stringstream params_stream;
  m._save_for_mobile(module_stream);

  // Load the mobile module back, then persist its (empty) named parameters.
  mobile::Module bc = _load_for_mobile(module_stream);
  _save_parameters(bc.named_parameters(), params_stream);

  // Deserializing must yield an empty map.
  auto mobile_params = _load_parameters(params_stream);
  AT_ASSERT(mobile_params.size() == 0);
}
162
// When no format is requested, _save_parameters must emit a ZIP container,
// identified by the "PK\x03\x04" local-file-header magic at offset 0.
TEST(MobileTest, SaveParametersDefaultsToZip) {
  // Save some empty parameters.
  std::map<std::string, at::Tensor> empty_parameters;
  std::stringstream ss_data;
  _save_parameters(empty_parameters, ss_data);

  // Materialize the stream contents once: each stringstream::str() call
  // copies the whole buffer, and the original re-copied it five times.
  const std::string data = ss_data.str();

  // ASSERT (not EXPECT): indexing data[0..3] below would be out-of-range
  // if the output were shorter than 4 bytes, so stop the test here.
  ASSERT_GE(data.size(), 4u);
  EXPECT_EQ(data[0], 'P');
  EXPECT_EQ(data[1], 'K');
  EXPECT_EQ(data[2], '\x03');
  EXPECT_EQ(data[3], '\x04');
}
176
// With use_flatbuffer=true, _save_parameters must emit a flatbuffer: the
// "PTMF" magic bytes live at offsets 4..7, while the first four bytes hold
// an offset to the actual flatbuffer data.
TEST(MobileTest, SaveParametersCanUseFlatbuffer) {
  // Save some empty parameters using flatbuffer.
  std::map<std::string, at::Tensor> empty_parameters;
  std::stringstream ss_data;
  _save_parameters(empty_parameters, ss_data, /*use_flatbuffer=*/true);

  // Copy the stream out once instead of rebuilding the string for every
  // comparison (stringstream::str() copies the whole buffer each call).
  const std::string data = ss_data.str();

  // ASSERT (not EXPECT): indexing data[4..7] below is out-of-range if the
  // output is shorter than 8 bytes, so the test must stop here on failure.
  ASSERT_GE(data.size(), 8u);
  EXPECT_EQ(data[4], 'P');
  EXPECT_EQ(data[5], 'T');
  EXPECT_EQ(data[6], 'M');
  EXPECT_EQ(data[7], 'F');
}
192
// Round-trips two named scalar parameters through the flatbuffer format and
// checks both the container magic and the recovered values.
TEST(MobileTest, SaveLoadParametersUsingFlatbuffers) {
  // Create some simple parameters to save.
  std::map<std::string, at::Tensor> input_params;
  input_params["four_by_ones"] = 4 * torch::ones({});
  input_params["three_by_ones"] = 3 * torch::ones({});

  // Serialize them using flatbuffers.
  std::stringstream data;
  _save_parameters(input_params, data, /*use_flatbuffer=*/true);

  // The flatbuffer magic bytes should be at offsets 4..7. Copy the stream
  // out once (str() duplicates the buffer per call) and guard the size
  // before indexing — the sibling tests do this, this one previously did not.
  const std::string bytes = data.str();
  ASSERT_GE(bytes.size(), 8u);
  EXPECT_EQ(bytes[4], 'P');
  EXPECT_EQ(bytes[5], 'T');
  EXPECT_EQ(bytes[6], 'M');
  EXPECT_EQ(bytes[7], 'F');

  // Read them back and check that they survived the trip.
  auto output_params = _load_parameters(data);
  EXPECT_EQ(output_params.size(), 2);
  {
    auto four_by_ones = 4 * torch::ones({});
    EXPECT_EQ(
        output_params["four_by_ones"].item<int>(), four_by_ones.item<int>());
  }
  {
    auto three_by_ones = 3 * torch::ones({});
    EXPECT_EQ(
        output_params["three_by_ones"].item<int>(), three_by_ones.item<int>());
  }
}
223
// Feeding _load_parameters data that is neither a ZIP archive nor a
// flatbuffer must raise an exception rather than misparse.
TEST(MobileTest, LoadParametersUnexpectedFormatShouldThrow) {
  // Build a stream whose contents match no known container format. It is
  // longer than 8 bytes because getFileFormat() needs that much data to
  // sniff the type.
  std::stringstream bad_data;
  bad_data << "abcd";
  bad_data << "efgh";
  bad_data << "ijkl";

  // Attempting to load parameters from it must fail.
  EXPECT_ANY_THROW(_load_parameters(bad_data));
}
236
// An empty stream carries no recognizable container format, so loading
// parameters from it must raise an exception.
TEST(MobileTest, LoadParametersEmptyDataShouldThrow) {
  std::stringstream no_bytes;
  EXPECT_ANY_THROW(_load_parameters(no_bytes));
}
242
// A stream whose header passes the flatbuffer format sniff (other tests in
// this file show the "PTMF" magic is read from offsets 4..7) but whose
// payload is garbage must be rejected with the specific "Malformed
// Flatbuffer module" error, not crash or misparse.
TEST(MobileTest, LoadParametersMalformedFlatbuffer) {
  // Manually create some data with Flatbuffer header.
  // Bytes 0..3 are deliberately the ZIP magic; bytes 4..7 spell "PTMF",
  // so the data is treated as a flatbuffer. Everything after that is
  // arbitrary bytes that cannot be a valid flatbuffer payload.
  std::stringstream bad_data;
  bad_data << "PK\x03\x04PTMF\x00\x00"
           << "*}NV\xb3\xfa\xdf\x00pa";

  // Loading parameters from it should throw an exception.
  ASSERT_THROWS_WITH_MESSAGE(
      _load_parameters(bad_data), "Malformed Flatbuffer module");
}
253
// Trains the same one-parameter linear model with the full JIT interpreter
// (torch::optim::SGD) and the lite interpreter (torch::jit::mobile::SGD) and
// checks that both optimizers drive the parameter to the same value.
TEST(LiteTrainerTest, SGD) {
  Module m("m");
  m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
  m.define(R"(
    def forward(self, x):
      b = 1.0
      return self.foo * x + b
  )");
  const double learning_rate = 0.1;
  const double momentum = 0.1;
  const int n_epoc = 10;
  // init: y = x + 1;
  // target: y = 2 x + 1
  std::vector<std::pair<Tensor, Tensor>> trainData{
      {1 * torch::ones({1}), 3 * torch::ones({1})},
  };
  // Reference: Full jit and torch::optim::SGD
  std::stringstream ms;
  m.save(ms);
  auto mm = load(ms);
  std::vector<::at::Tensor> parameters;
  // Bind by const reference to avoid a redundant Tensor handle copy per
  // iteration (clang-tidy: performance-for-range-copy).
  for (const auto& parameter : mm.parameters()) {
    parameters.emplace_back(parameter);
  }
  ::torch::optim::SGD optimizer(
      parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  for (int epoc = 0; epoc < n_epoc; ++epoc) {
    // Structured bindings by const reference: the originals copied both
    // tensors out of the pair on every iteration.
    for (const auto& [source, targets] : trainData) {
      optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = mm.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      optimizer.step();
    }
  }
  // Test: lite interpreter and torch::jit::mobile::SGD
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<::at::Tensor> bc_parameters = bc.parameters();
  ::torch::jit::mobile::SGD bc_optimizer(
      bc_parameters,
      ::torch::jit::mobile::SGDOptions(learning_rate).momentum(momentum));
  for (int epoc = 0; epoc < n_epoc; ++epoc) {
    for (const auto& [source, targets] : trainData) {
      bc_optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = bc.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      bc_optimizer.step();
    }
  }
  // Identical starting weights, data, and hyperparameters: the two SGD
  // implementations must produce bit-identical float results.
  AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>());
}
311
312 namespace {
313 struct DummyDataset : torch::data::datasets::Dataset<DummyDataset, int> {
DummyDatasettorch::jit::__anonf54ddf480111::DummyDataset314 explicit DummyDataset(size_t size = 100) : size_(size) {}
315
gettorch::jit::__anonf54ddf480111::DummyDataset316 int get(size_t index) override {
317 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
318 return 1 + index;
319 }
sizetorch::jit::__anonf54ddf480111::DummyDataset320 torch::optional<size_t> size() const override {
321 return size_;
322 }
323
324 size_t size_;
325 };
326 } // namespace
327
// The mobile SequentialSampler must plug into torch::data::make_data_loader
// and produce the dataset elements strictly in order.
TEST(LiteTrainerTest, SequentialSampler) {
  const int kBatchSize = 10;
  auto data_loader = torch::data::make_data_loader<mobile::SequentialSampler>(
      DummyDataset(25), kBatchSize);
  // DummyDataset yields 1, 2, 3, ... so a sequential walk must match a
  // simple running counter across batch boundaries.
  int expected = 1;
  for (const auto& batch : *data_loader) {
    for (const auto& example : batch) {
      AT_ASSERT(example == expected);
      ++expected;
    }
  }
}
341
// Every index drawn from a RandomSampler over 10 elements must fall in
// [0, 10), and once all 10 are consumed the sampler is exhausted.
TEST(LiteTrainerTest, RandomSamplerReturnsIndicesInCorrectRange) {
  mobile::RandomSampler sampler(10);

  // Draw batches of 3, 5, then 2 (consuming all 10 indices) and verify the
  // range of every returned index.
  for (const size_t request : {3, 5, 2}) {
    const std::vector<size_t> indices = sampler.next(request).value();
    for (const size_t index : indices) {
      AT_ASSERT(index < 10);
    }
  }

  // 3 + 5 + 2 == 10, so a further request yields nothing.
  AT_ASSERT(sampler.next(10).has_value() == false);
}
362
// When fewer indices remain than requested, the sampler returns just the
// remainder; a request after exhaustion returns no value at all.
TEST(LiteTrainerTest, RandomSamplerReturnsLessValuesForLastBatch) {
  mobile::RandomSampler sampler(5);
  const auto first_batch = sampler.next(3);
  AT_ASSERT(first_batch.value().size() == 3);
  // Only 2 of the 5 indices are left for this oversized request.
  const auto remainder = sampler.next(100);
  AT_ASSERT(remainder.value().size() == 2);
  AT_ASSERT(!sampler.next(2).has_value());
}
369
// reset() must restore a drained sampler to its full population.
TEST(LiteTrainerTest, RandomSamplerResetsWell) {
  mobile::RandomSampler sampler(5);
  // Drain all 5 indices and confirm exhaustion.
  const auto drain = [&sampler]() {
    AT_ASSERT(sampler.next(5).value().size() == 5);
    AT_ASSERT(!sampler.next(2).has_value());
  };
  drain();
  sampler.reset();
  // After the reset the full sequence is available again.
  drain();
}
378
// reset(n) must re-seed the sampler with a new population size, whether the
// new size is larger or smaller than the previous one.
TEST(LiteTrainerTest, RandomSamplerResetsWithNewSizeWell) {
  mobile::RandomSampler sampler(5);
  // Draws exactly `count` indices, then confirms the sampler is exhausted.
  const auto expect_exactly = [&sampler](size_t count) {
    AT_ASSERT(sampler.next(count).value().size() == count);
    AT_ASSERT(!sampler.next(2).has_value());
  };
  expect_exactly(5);
  sampler.reset(7); // grow
  expect_exactly(7);
  sampler.reset(3); // shrink
  expect_exactly(3);
}
390
391 } // namespace jit
392 } // namespace torch
393