1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker #include <cmath>
9*da0073e9SAndroid Build Coastguard Worker #include <cstdlib>
10*da0073e9SAndroid Build Coastguard Worker #include <random>
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
13*da0073e9SAndroid Build Coastguard Worker using namespace torch::test;
14*da0073e9SAndroid Build Coastguard Worker
// Approximation of pi used to convert the CartPole failure angle from degrees
// to radians.
constexpr double kPi = 3.1415926535898;
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker class CartPole {
18*da0073e9SAndroid Build Coastguard Worker // Translated from openai/gym's cartpole.py
19*da0073e9SAndroid Build Coastguard Worker public:
20*da0073e9SAndroid Build Coastguard Worker double gravity = 9.8;
21*da0073e9SAndroid Build Coastguard Worker double masscart = 1.0;
22*da0073e9SAndroid Build Coastguard Worker double masspole = 0.1;
23*da0073e9SAndroid Build Coastguard Worker double total_mass = (masspole + masscart);
24*da0073e9SAndroid Build Coastguard Worker double length = 0.5; // actually half the pole's length;
25*da0073e9SAndroid Build Coastguard Worker double polemass_length = (masspole * length);
26*da0073e9SAndroid Build Coastguard Worker double force_mag = 10.0;
27*da0073e9SAndroid Build Coastguard Worker double tau = 0.02; // seconds between state updates;
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker // Angle at which to fail the episode
30*da0073e9SAndroid Build Coastguard Worker double theta_threshold_radians = 12 * 2 * kPi / 360;
31*da0073e9SAndroid Build Coastguard Worker double x_threshold = 2.4;
32*da0073e9SAndroid Build Coastguard Worker int steps_beyond_done = -1;
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker torch::Tensor state;
35*da0073e9SAndroid Build Coastguard Worker double reward;
36*da0073e9SAndroid Build Coastguard Worker bool done;
37*da0073e9SAndroid Build Coastguard Worker int step_ = 0;
38*da0073e9SAndroid Build Coastguard Worker
getState()39*da0073e9SAndroid Build Coastguard Worker torch::Tensor getState() {
40*da0073e9SAndroid Build Coastguard Worker return state;
41*da0073e9SAndroid Build Coastguard Worker }
42*da0073e9SAndroid Build Coastguard Worker
getReward()43*da0073e9SAndroid Build Coastguard Worker double getReward() {
44*da0073e9SAndroid Build Coastguard Worker return reward;
45*da0073e9SAndroid Build Coastguard Worker }
46*da0073e9SAndroid Build Coastguard Worker
isDone()47*da0073e9SAndroid Build Coastguard Worker double isDone() {
48*da0073e9SAndroid Build Coastguard Worker return done;
49*da0073e9SAndroid Build Coastguard Worker }
50*da0073e9SAndroid Build Coastguard Worker
reset()51*da0073e9SAndroid Build Coastguard Worker void reset() {
52*da0073e9SAndroid Build Coastguard Worker state = torch::empty({4}).uniform_(-0.05, 0.05);
53*da0073e9SAndroid Build Coastguard Worker steps_beyond_done = -1;
54*da0073e9SAndroid Build Coastguard Worker step_ = 0;
55*da0073e9SAndroid Build Coastguard Worker }
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
CartPole()58*da0073e9SAndroid Build Coastguard Worker CartPole() {
59*da0073e9SAndroid Build Coastguard Worker reset();
60*da0073e9SAndroid Build Coastguard Worker }
61*da0073e9SAndroid Build Coastguard Worker
step(int action)62*da0073e9SAndroid Build Coastguard Worker void step(int action) {
63*da0073e9SAndroid Build Coastguard Worker auto x = state[0].item<float>();
64*da0073e9SAndroid Build Coastguard Worker auto x_dot = state[1].item<float>();
65*da0073e9SAndroid Build Coastguard Worker auto theta = state[2].item<float>();
66*da0073e9SAndroid Build Coastguard Worker auto theta_dot = state[3].item<float>();
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker auto force = (action == 1) ? force_mag : -force_mag;
69*da0073e9SAndroid Build Coastguard Worker auto costheta = std::cos(theta);
70*da0073e9SAndroid Build Coastguard Worker auto sintheta = std::sin(theta);
71*da0073e9SAndroid Build Coastguard Worker auto temp = (force + polemass_length * theta_dot * theta_dot * sintheta) /
72*da0073e9SAndroid Build Coastguard Worker total_mass;
73*da0073e9SAndroid Build Coastguard Worker auto thetaacc = (gravity * sintheta - costheta * temp) /
74*da0073e9SAndroid Build Coastguard Worker (length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass));
75*da0073e9SAndroid Build Coastguard Worker auto xacc = temp - polemass_length * thetaacc * costheta / total_mass;
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker x = x + tau * x_dot;
78*da0073e9SAndroid Build Coastguard Worker x_dot = x_dot + tau * xacc;
79*da0073e9SAndroid Build Coastguard Worker theta = theta + tau * theta_dot;
80*da0073e9SAndroid Build Coastguard Worker theta_dot = theta_dot + tau * thetaacc;
81*da0073e9SAndroid Build Coastguard Worker state = torch::tensor({x, x_dot, theta, theta_dot});
82*da0073e9SAndroid Build Coastguard Worker
83*da0073e9SAndroid Build Coastguard Worker done = x < -x_threshold || x > x_threshold ||
84*da0073e9SAndroid Build Coastguard Worker theta < -theta_threshold_radians || theta > theta_threshold_radians ||
85*da0073e9SAndroid Build Coastguard Worker step_ > 200;
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker if (!done) {
88*da0073e9SAndroid Build Coastguard Worker reward = 1.0;
89*da0073e9SAndroid Build Coastguard Worker } else if (steps_beyond_done == -1) {
90*da0073e9SAndroid Build Coastguard Worker // Pole just fell!
91*da0073e9SAndroid Build Coastguard Worker steps_beyond_done = 0;
92*da0073e9SAndroid Build Coastguard Worker reward = 0;
93*da0073e9SAndroid Build Coastguard Worker } else {
94*da0073e9SAndroid Build Coastguard Worker if (steps_beyond_done == 0) {
95*da0073e9SAndroid Build Coastguard Worker AT_ASSERT(false); // Can't do this
96*da0073e9SAndroid Build Coastguard Worker }
97*da0073e9SAndroid Build Coastguard Worker }
98*da0073e9SAndroid Build Coastguard Worker step_++;
99*da0073e9SAndroid Build Coastguard Worker }
100*da0073e9SAndroid Build Coastguard Worker };
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker template <typename M, typename F, typename O>
test_mnist(size_t batch_size,size_t number_of_epochs,bool with_cuda,M && model,F && forward_op,O && optimizer)103*da0073e9SAndroid Build Coastguard Worker bool test_mnist(
104*da0073e9SAndroid Build Coastguard Worker size_t batch_size,
105*da0073e9SAndroid Build Coastguard Worker size_t number_of_epochs,
106*da0073e9SAndroid Build Coastguard Worker bool with_cuda,
107*da0073e9SAndroid Build Coastguard Worker M&& model,
108*da0073e9SAndroid Build Coastguard Worker F&& forward_op,
109*da0073e9SAndroid Build Coastguard Worker O&& optimizer) {
110*da0073e9SAndroid Build Coastguard Worker std::string mnist_path = "mnist";
111*da0073e9SAndroid Build Coastguard Worker if (const char* user_mnist_path = getenv("TORCH_CPP_TEST_MNIST_PATH")) {
112*da0073e9SAndroid Build Coastguard Worker mnist_path = user_mnist_path;
113*da0073e9SAndroid Build Coastguard Worker }
114*da0073e9SAndroid Build Coastguard Worker
115*da0073e9SAndroid Build Coastguard Worker auto train_dataset =
116*da0073e9SAndroid Build Coastguard Worker torch::data::datasets::MNIST(
117*da0073e9SAndroid Build Coastguard Worker mnist_path, torch::data::datasets::MNIST::Mode::kTrain)
118*da0073e9SAndroid Build Coastguard Worker .map(torch::data::transforms::Stack<>());
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker auto data_loader =
121*da0073e9SAndroid Build Coastguard Worker torch::data::make_data_loader(std::move(train_dataset), batch_size);
122*da0073e9SAndroid Build Coastguard Worker
123*da0073e9SAndroid Build Coastguard Worker torch::Device device(with_cuda ? torch::kCUDA : torch::kCPU);
124*da0073e9SAndroid Build Coastguard Worker model->to(device);
125*da0073e9SAndroid Build Coastguard Worker
126*da0073e9SAndroid Build Coastguard Worker for (const auto epoch : c10::irange(number_of_epochs)) {
127*da0073e9SAndroid Build Coastguard Worker (void)epoch; // Suppress unused variable warning
128*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(performance-for-range-copy)
129*da0073e9SAndroid Build Coastguard Worker for (torch::data::Example<> batch : *data_loader) {
130*da0073e9SAndroid Build Coastguard Worker auto data = batch.data.to(device);
131*da0073e9SAndroid Build Coastguard Worker auto targets = batch.target.to(device);
132*da0073e9SAndroid Build Coastguard Worker torch::Tensor prediction = forward_op(std::move(data));
133*da0073e9SAndroid Build Coastguard Worker // NOLINTNEXTLINE(performance-move-const-arg)
134*da0073e9SAndroid Build Coastguard Worker torch::Tensor loss = torch::nll_loss(prediction, std::move(targets));
135*da0073e9SAndroid Build Coastguard Worker AT_ASSERT(!torch::isnan(loss).any().item<int64_t>());
136*da0073e9SAndroid Build Coastguard Worker optimizer.zero_grad();
137*da0073e9SAndroid Build Coastguard Worker loss.backward();
138*da0073e9SAndroid Build Coastguard Worker optimizer.step();
139*da0073e9SAndroid Build Coastguard Worker }
140*da0073e9SAndroid Build Coastguard Worker }
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker torch::NoGradGuard guard;
143*da0073e9SAndroid Build Coastguard Worker torch::data::datasets::MNIST test_dataset(
144*da0073e9SAndroid Build Coastguard Worker mnist_path, torch::data::datasets::MNIST::Mode::kTest);
145*da0073e9SAndroid Build Coastguard Worker auto images = test_dataset.images().to(device),
146*da0073e9SAndroid Build Coastguard Worker targets = test_dataset.targets().to(device);
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker auto result = std::get<1>(forward_op(images).max(/*dim=*/1));
149*da0073e9SAndroid Build Coastguard Worker torch::Tensor correct = (result == targets).to(torch::kFloat32);
150*da0073e9SAndroid Build Coastguard Worker return correct.sum().item<float>() > (test_dataset.size().value() * 0.8);
151*da0073e9SAndroid Build Coastguard Worker }
152*da0073e9SAndroid Build Coastguard Worker
153*da0073e9SAndroid Build Coastguard Worker struct IntegrationTest : torch::test::SeedingFixture {};
154*da0073e9SAndroid Build Coastguard Worker
// Trains a small policy/value network on the CartPole simulation. A shared
// Linear(4, 128) layer feeds a 2-way policy head and a scalar value head.
// After every episode the discounted, normalized returns drive one Adam step.
// The test passes once the smoothed episode length exceeds 150 and fails if
// that has not happened by episode 3000.
TEST_F(IntegrationTest, CartPole) {
  torch::manual_seed(0);
  auto model = std::make_shared<SimpleContainer>();
  auto linear = model->add(Linear(4, 128), "linear");
  auto policyHead = model->add(Linear(128, 2), "policy");
  auto valueHead = model->add(Linear(128, 1), "action");
  auto optimizer = torch::optim::Adam(model->parameters(), 1e-3);

  // Per-episode buffers; filled while acting, consumed by finishEpisode.
  std::vector<torch::Tensor> saved_log_probs;
  std::vector<torch::Tensor> saved_values;
  std::vector<float> rewards;

  // Returns (action probabilities, state value) for an input state.
  auto forward = [&](torch::Tensor inp) {
    auto hidden = linear->forward(inp).clamp_min(0);
    torch::Tensor action_scores = policyHead->forward(hidden);
    torch::Tensor state_value = valueHead->forward(hidden);
    return std::make_tuple(torch::softmax(action_scores, -1), state_value);
  };

  // Samples an action for `state` and records its log-probability and the
  // predicted value for the later update.
  auto selectAction = [&](torch::Tensor state) {
    // Only work on single state right now, change index to gather for batch
    torch::Tensor probs;
    torch::Tensor value;
    std::tie(probs, value) = forward(state);
    auto action = probs.multinomial(1)[0].item<int32_t>();
    // Compute the log prob of a multinomial distribution.
    // This should probably be actually implemented in autogradpp...
    auto normalized = probs / probs.sum(-1, true);
    saved_log_probs.push_back(normalized[action].log());
    saved_values.push_back(value);
    return action;
  };

  // Turns the recorded rewards into discounted returns, performs one
  // optimizer step on the combined policy and value loss, and clears the
  // episode buffers.
  auto finishEpisode = [&] {
    // Walk backwards so each entry becomes reward + 0.99 * (following return).
    // The running value stays in double; only the stored copy is narrowed.
    double discounted = 0.;
    for (auto it = rewards.rbegin(); it != rewards.rend(); ++it) {
      discounted = *it + 0.99 * discounted;
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      *it = discounted;
    }
    auto returns = torch::from_blob(
        rewards.data(), {static_cast<int64_t>(rewards.size())});
    // Normalization builds a new tensor, so `returns` no longer aliases the
    // vector's storage afterwards.
    returns = (returns - returns.mean()) / (returns.std() + 1e-5);

    std::vector<torch::Tensor> policy_loss;
    std::vector<torch::Tensor> value_loss;
    const size_t num_steps = saved_log_probs.size();
    for (size_t i = 0; i < num_steps; ++i) {
      // item<float>() detaches the value estimate from the graph here.
      auto advantage = returns[i] - saved_values[i].item<float>();
      policy_loss.push_back(-advantage * saved_log_probs[i]);
      value_loss.push_back(
          torch::smooth_l1_loss(saved_values[i], torch::ones(1) * returns[i]));
    }

    auto loss =
        torch::stack(policy_loss).sum() + torch::stack(value_loss).sum();

    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    rewards.clear();
    saved_log_probs.clear();
    saved_values.clear();
  };

  CartPole env;
  double running_reward = 10.0;
  for (size_t episode = 0;; episode++) {
    env.reset();
    auto state = env.getState();

    // `t` ends up as the index of the step that finished the episode, or
    // 10000 if the step limit was reached first.
    int t = 0;
    while (t < 10000) {
      env.step(selectAction(state));
      state = env.getState();
      rewards.push_back(env.getReward());
      if (env.isDone()) {
        break;
      }
      ++t;
    }

    // Exponential moving average of the episode length.
    running_reward = running_reward * 0.99 + t * 0.01;
    finishEpisode();
    /*
    if (episode % 10 == 0) {
      printf("Episode %i\tLast length: %5d\tAverage length: %.2f\n",
              episode, t, running_reward);
    }
    */
    if (running_reward > 150) {
      break;
    }
    ASSERT_LT(episode, 3000);
  }
}
253*da0073e9SAndroid Build Coastguard Worker
// Trains a small conv net with dropout on MNIST using CUDA (3 epochs, batch
// size 32, SGD with momentum) and requires test_mnist to report success.
TEST_F(IntegrationTest, MNIST_CUDA) {
  torch::manual_seed(0);
  auto model = std::make_shared<SimpleContainer>();
  auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
  auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
  // The dropout modules are created standalone; they are not added to `model`.
  auto drop = Dropout(0.3);
  auto drop2d = Dropout2d(0.3);
  auto linear1 = model->add(Linear(320, 50), "linear1");
  auto linear2 = model->add(Linear(50, 10), "linear2");

  // conv -> pool -> relu, conv -> dropout2d -> pool -> relu, then two linear
  // layers with dropout in between, ending in log-softmax over dimension 1.
  auto forward = [&](torch::Tensor input) {
    auto features = torch::max_pool2d(conv1->forward(input), {2, 2}).relu();
    features = drop2d->forward(conv2->forward(features));
    features = torch::max_pool2d(features, {2, 2}).relu();

    // Flatten to 320 columns to match linear1's input size.
    auto hidden = linear1->forward(features.view({-1, 320})).clamp_min(0);
    hidden = drop->forward(hidden);
    return torch::log_softmax(linear2->forward(hidden), 1);
  };

  torch::optim::SGD optimizer(
      model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));

  const bool accurate_enough = test_mnist(
      /*batch_size=*/32,
      /*number_of_epochs=*/3,
      /*with_cuda=*/true,
      model,
      forward,
      optimizer);
  ASSERT_TRUE(accurate_enough);
}
289*da0073e9SAndroid Build Coastguard Worker
// Trains a small conv net with batch normalization on MNIST using CUDA
// (3 epochs, batch size 32, SGD with momentum) and requires test_mnist to
// report success.
TEST_F(IntegrationTest, MNISTBatchNorm_CUDA) {
  torch::manual_seed(0);
  auto model = std::make_shared<SimpleContainer>();
  auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
  auto batchnorm2d = model->add(BatchNorm2d(10), "batchnorm2d");
  auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
  auto linear1 = model->add(Linear(320, 50), "linear1");
  auto batchnorm1 = model->add(BatchNorm1d(50), "batchnorm1");
  auto linear2 = model->add(Linear(50, 10), "linear2");

  // conv -> pool -> relu -> batchnorm2d, conv -> pool -> relu, then two linear
  // layers with batchnorm1d in between, ending in log-softmax over dimension 1.
  auto forward = [&](torch::Tensor input) {
    auto features = torch::max_pool2d(conv1->forward(input), {2, 2}).relu();
    features = conv2->forward(batchnorm2d->forward(features));
    features = torch::max_pool2d(features, {2, 2}).relu();

    // Flatten to 320 columns to match linear1's input size.
    auto hidden = linear1->forward(features.view({-1, 320})).clamp_min(0);
    hidden = batchnorm1->forward(hidden);
    return torch::log_softmax(linear2->forward(hidden), 1);
  };

  torch::optim::SGD optimizer(
      model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));

  const bool accurate_enough = test_mnist(
      /*batch_size=*/32,
      /*number_of_epochs=*/3,
      /*with_cuda=*/true,
      model,
      forward,
      optimizer);
  ASSERT_TRUE(accurate_enough);
}
325