xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_lite_trainer.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <test/cpp/jit/test_utils.h>

#include <gtest/gtest.h>

#include <c10/core/TensorOptions.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/import_data.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/train/export_data.h>
#include <torch/csrc/jit/mobile/train/optim/sgd.h>
#include <torch/csrc/jit/mobile/train/random.h>
#include <torch/csrc/jit/mobile/train/sequential.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/data/dataloader.h>
#include <torch/torch.h>

19 // Tests go in torch::jit
20 namespace torch {
21 namespace jit {
22 
// Trains a tiny linear model (y = foo * x + b) twice -- once through the full
// JIT interpreter and once through the lite interpreter -- using identical
// torch::optim::SGD settings, then checks that both runs learn the exact same
// value for the single parameter.
TEST(LiteTrainerTest, Params) {
  Module m("m");
  m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
  m.define(R"(
    def forward(self, x):
      b = 1.0
      return self.foo * x + b
  )");
  const double learning_rate = 0.1, momentum = 0.1;
  const int n_epoc = 10;
  // init: y = x + 1;
  // target: y = 2 x + 1
  std::vector<std::pair<Tensor, Tensor>> trainData{
      {1 * torch::ones({1}), 3 * torch::ones({1})},
  };
  // Reference run: round-trip the module through full-JIT save/load.
  std::stringstream ms;
  m.save(ms);
  auto mm = load(ms);
  std::vector<::at::Tensor> parameters;
  for (const auto& parameter : mm.parameters()) {
    parameters.emplace_back(parameter);
  }
  ::torch::optim::SGD optimizer(
      parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  for (int epoc = 0; epoc < n_epoc; ++epoc) {
    for (const auto& data : trainData) {
      auto source = data.first, targets = data.second;
      optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = mm.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      optimizer.step();
    }
  }
  // Test run: the same training loop on the lite-interpreter module.
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<::at::Tensor> bc_parameters = bc.parameters();
  ::torch::optim::SGD bc_optimizer(
      bc_parameters,
      ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  for (int epoc = 0; epoc < n_epoc; ++epoc) {
    for (const auto& data : trainData) {
      auto source = data.first, targets = data.second;
      bc_optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = bc.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      bc_optimizer.step();
    }
  }
  // Both runs executed identical float operations, so exact equality holds.
  AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>());
}
80 
// TODO Re-enable these tests after parameters are correctly loaded on mobile
82 /*
83 TEST(MobileTest, NamedParameters) {
84   Module m("m");
85   m.register_parameter("foo", torch::ones({}), false);
86   m.define(R"(
87     def add_it(self, x):
88       b = 4
89       return self.foo + x + b
90   )");
91   Module child("m2");
92   child.register_parameter("foo", 4 * torch::ones({}), false);
93   child.register_parameter("bar", 4 * torch::ones({}), false);
94   m.register_module("child1", child);
95   m.register_module("child2", child.clone());
96   std::stringstream ss;
97   m._save_for_mobile(ss);
98   mobile::Module bc = _load_for_mobile(ss);
99 
100   auto full_params = m.named_parameters();
101   auto mobile_params = bc.named_parameters();
102   AT_ASSERT(full_params.size() == mobile_params.size());
103   for (const auto& e : full_params) {
104     AT_ASSERT(e.value.item().toInt() ==
105     mobile_params[e.name].item().toInt());
106   }
107 }
108 
109 TEST(MobileTest, SaveLoadParameters) {
110   Module m("m");
111   m.register_parameter("foo", torch::ones({}), false);
112   m.define(R"(
113     def add_it(self, x):
114       b = 4
115       return self.foo + x + b
116   )");
117   Module child("m2");
118   child.register_parameter("foo", 4 * torch::ones({}), false);
119   child.register_parameter("bar", 3 * torch::ones({}), false);
120   m.register_module("child1", child);
121   m.register_module("child2", child.clone());
122   auto full_params = m.named_parameters();
123   std::stringstream ss;
124   std::stringstream ss_data;
125   m._save_for_mobile(ss);
126 
127   // load mobile module, save mobile named parameters
128   mobile::Module bc = _load_for_mobile(ss);
129   _save_parameters(bc.named_parameters(), ss_data);
130 
131   // load back the named parameters, compare to full-jit Module's
132   auto mobile_params = _load_parameters(ss_data);
133   AT_ASSERT(full_params.size() == mobile_params.size());
134   for (const auto& e : full_params) {
135     AT_ASSERT(e.value.item<int>() == mobile_params[e.name].item<int>());
136   }
137 }
138 */
139 
// Round-trips the named parameters of a module hierarchy that has none and
// verifies that the reloaded parameter map is empty.
TEST(MobileTest, SaveLoadParametersEmpty) {
  Module m("m");
  m.define(R"(
    def add_it(self, x):
      b = 4
      return x + b
  )");
  Module child("m2");
  m.register_module("child1", child);
  m.register_module("child2", child.clone());
  std::stringstream ss;
  std::stringstream ss_data;
  m._save_for_mobile(ss);

  // Load the mobile module and save its (empty) named parameters.
  mobile::Module bc = _load_for_mobile(ss);
  _save_parameters(bc.named_parameters(), ss_data);

  // Load the named parameters back; the map must be empty.
  auto mobile_params = _load_parameters(ss_data);
  AT_ASSERT(mobile_params.size() == 0);
}
162 
// Verifies that _save_parameters() writes a ZIP container by default: the
// stream must start with the ZIP local-file-header magic "PK\x03\x04".
TEST(MobileTest, SaveParametersDefaultsToZip) {
  // Save some empty parameters.
  std::map<std::string, at::Tensor> empty_parameters;
  std::stringstream ss_data;
  _save_parameters(empty_parameters, ss_data);

  // Copy the stream out once instead of materializing a fresh string for
  // every byte check.
  const std::string data = ss_data.str();
  EXPECT_GE(data.size(), 4);
  EXPECT_EQ(data[0], 'P');
  EXPECT_EQ(data[1], 'K');
  EXPECT_EQ(data[2], '\x03');
  EXPECT_EQ(data[3], '\x04');
}
176 
// Verifies that _save_parameters(..., use_flatbuffer=true) emits a flatbuffer:
// the "PTMF" file identifier must appear at offsets 4..7 (the first four bytes
// hold an offset to the actual flatbuffer data).
TEST(MobileTest, SaveParametersCanUseFlatbuffer) {
  // Save some empty parameters using flatbuffer.
  std::map<std::string, at::Tensor> empty_parameters;
  std::stringstream ss_data;
  _save_parameters(empty_parameters, ss_data, /*use_flatbuffer=*/true);

  // Copy the stream out once instead of materializing a fresh string for
  // every byte check.
  const std::string data = ss_data.str();
  EXPECT_GE(data.size(), 8);
  EXPECT_EQ(data[4], 'P');
  EXPECT_EQ(data[5], 'T');
  EXPECT_EQ(data[6], 'M');
  EXPECT_EQ(data[7], 'F');
}
192 
// Saves two scalar parameters via the flatbuffer path, checks the container
// magic, then loads them back and verifies the values survived the trip.
TEST(MobileTest, SaveLoadParametersUsingFlatbuffers) {
  // Create some simple parameters to save.
  std::map<std::string, at::Tensor> input_params;
  input_params["four_by_ones"] = 4 * torch::ones({});
  input_params["three_by_ones"] = 3 * torch::ones({});

  // Serialize them using flatbuffers.
  std::stringstream data;
  _save_parameters(input_params, data, /*use_flatbuffer=*/true);

  // The flatbuffer magic bytes should be at offsets 4..7. Guard the length
  // first so a short stream fails the test instead of reading out of range.
  const std::string bytes = data.str();
  ASSERT_GE(bytes.size(), 8);
  EXPECT_EQ(bytes[4], 'P');
  EXPECT_EQ(bytes[5], 'T');
  EXPECT_EQ(bytes[6], 'M');
  EXPECT_EQ(bytes[7], 'F');

  // Read them back and check that they survived the trip.
  auto output_params = _load_parameters(data);
  EXPECT_EQ(output_params.size(), 2);
  {
    auto four_by_ones = 4 * torch::ones({});
    EXPECT_EQ(
        output_params["four_by_ones"].item<int>(), four_by_ones.item<int>());
  }
  {
    auto three_by_ones = 3 * torch::ones({});
    EXPECT_EQ(
        output_params["three_by_ones"].item<int>(), three_by_ones.item<int>());
  }
}
223 
// Loading parameters from a stream that is neither a ZIP nor a flatbuffer
// must throw rather than misbehave.
TEST(MobileTest, LoadParametersUnexpectedFormatShouldThrow) {
  // Manually create some data that doesn't look like a ZIP or Flatbuffer file.
  // Make sure it's longer than 8 bytes, since getFileFormat() needs that much
  // data to detect the type.
  std::stringstream bad_data;
  bad_data << "abcd"
           << "efgh"
           << "ijkl";

  // Loading parameters from it should throw an exception.
  EXPECT_ANY_THROW(_load_parameters(bad_data));
}
236 
// Loading parameters from an empty data stream should throw an exception.
TEST(MobileTest, LoadParametersEmptyDataShouldThrow) {
  std::stringstream empty;
  EXPECT_ANY_THROW(_load_parameters(empty));
}
242 
// A stream carrying the flatbuffer header but garbage payload must be
// rejected with the specific "Malformed Flatbuffer module" error.
TEST(MobileTest, LoadParametersMalformedFlatbuffer) {
  // Manually create some data with a Flatbuffer header.
  std::stringstream bad_data;
  bad_data << "PK\x03\x04PTMF\x00\x00"
           << "*}NV\xb3\xfa\xdf\x00pa";

  // Loading parameters from it should throw an exception.
  ASSERT_THROWS_WITH_MESSAGE(
      _load_parameters(bad_data), "Malformed Flatbuffer module");
}
253 
// Trains the same tiny linear model with torch::optim::SGD (full JIT,
// reference) and with torch::jit::mobile::SGD (lite interpreter, under test),
// then checks that both optimizers produce the exact same learned parameter.
TEST(LiteTrainerTest, SGD) {
  Module m("m");
  m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
  m.define(R"(
    def forward(self, x):
      b = 1.0
      return self.foo * x + b
  )");
  const double learning_rate = 0.1, momentum = 0.1;
  const int n_epoc = 10;
  // init: y = x + 1;
  // target: y = 2 x + 1
  std::vector<std::pair<Tensor, Tensor>> trainData{
      {1 * torch::ones({1}), 3 * torch::ones({1})},
  };
  // Reference: full JIT and torch::optim::SGD.
  std::stringstream ms;
  m.save(ms);
  auto mm = load(ms);
  std::vector<::at::Tensor> parameters;
  for (const auto& parameter : mm.parameters()) {
    parameters.emplace_back(parameter);
  }
  ::torch::optim::SGD optimizer(
      parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  for (int epoc = 0; epoc < n_epoc; ++epoc) {
    for (const auto& data : trainData) {
      auto source = data.first, targets = data.second;
      optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = mm.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      optimizer.step();
    }
  }
  // Test: lite interpreter and torch::jit::mobile::SGD.
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<::at::Tensor> bc_parameters = bc.parameters();
  ::torch::jit::mobile::SGD bc_optimizer(
      bc_parameters,
      ::torch::jit::mobile::SGDOptions(learning_rate).momentum(momentum));
  for (int epoc = 0; epoc < n_epoc; ++epoc) {
    for (const auto& data : trainData) {
      auto source = data.first, targets = data.second;
      bc_optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = bc.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      bc_optimizer.step();
    }
  }
  // Identical float operations on both paths, so exact equality holds.
  AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>());
}
311 
312 namespace {
313 struct DummyDataset : torch::data::datasets::Dataset<DummyDataset, int> {
DummyDatasettorch::jit::__anonf54ddf480111::DummyDataset314   explicit DummyDataset(size_t size = 100) : size_(size) {}
315 
gettorch::jit::__anonf54ddf480111::DummyDataset316   int get(size_t index) override {
317     // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
318     return 1 + index;
319   }
sizetorch::jit::__anonf54ddf480111::DummyDataset320   torch::optional<size_t> size() const override {
321     return size_;
322   }
323 
324   size_t size_;
325 };
326 } // namespace
327 
// A mobile::SequentialSampler driving a dataloader must visit DummyDataset's
// elements in order: 1, 2, ..., 25 across batches.
TEST(LiteTrainerTest, SequentialSampler) {
  // test that sampler can be used with dataloader
  const int kBatchSize = 10;
  auto data_loader = torch::data::make_data_loader<mobile::SequentialSampler>(
      DummyDataset(25), kBatchSize);
  int expected = 1;
  for (const auto& batch : *data_loader) {
    for (const auto& example : batch) {
      AT_ASSERT(expected == example);
      expected++;
    }
  }
}
341 
// Every index drawn from a RandomSampler over 10 elements must lie in
// [0, 10), and once all 10 are consumed the sampler is exhausted.
TEST(LiteTrainerTest, RandomSamplerReturnsIndicesInCorrectRange) {
  mobile::RandomSampler sampler(10);

  // Draw batches of 3, 5, and 2 -- consuming all 10 indices.
  for (size_t batch_size : {3, 5, 2}) {
    std::vector<size_t> indices = sampler.next(batch_size).value();
    AT_ASSERT(indices.size() == batch_size);
    for (auto i : indices) {
      AT_ASSERT(i < 10);
    }
  }

  // The pool is empty now, so any further request yields no value.
  AT_ASSERT(sampler.next(10).has_value() == false);
}
362 
// When fewer indices remain than requested, the sampler returns just the
// remainder; once empty it returns no value at all.
TEST(LiteTrainerTest, RandomSamplerReturnsLessValuesForLastBatch) {
  mobile::RandomSampler sampler(5);
  AT_ASSERT(sampler.next(3).value().size() == 3);
  AT_ASSERT(sampler.next(100).value().size() == 2);
  AT_ASSERT(sampler.next(2).has_value() == false);
}
369 
// reset() must restore an exhausted sampler to its full size.
TEST(LiteTrainerTest, RandomSamplerResetsWell) {
  mobile::RandomSampler sampler(5);
  AT_ASSERT(sampler.next(5).value().size() == 5);
  AT_ASSERT(sampler.next(2).has_value() == false);
  sampler.reset();
  AT_ASSERT(sampler.next(5).value().size() == 5);
  AT_ASSERT(sampler.next(2).has_value() == false);
}
378 
// reset(new_size) must both refill the sampler and change its pool size,
// growing (5 -> 7) or shrinking (7 -> 3) as requested.
TEST(LiteTrainerTest, RandomSamplerResetsWithNewSizeWell) {
  mobile::RandomSampler sampler(5);
  AT_ASSERT(sampler.next(5).value().size() == 5);
  AT_ASSERT(sampler.next(2).has_value() == false);
  sampler.reset(7);
  AT_ASSERT(sampler.next(7).value().size() == 7);
  AT_ASSERT(sampler.next(2).has_value() == false);
  sampler.reset(3);
  AT_ASSERT(sampler.next(3).value().size() == 3);
  AT_ASSERT(sampler.next(2).has_value() == false);
}
390 
391 } // namespace jit
392 } // namespace torch
393