xref: /aosp_15_r20/external/pytorch/test/cpp/api/integration.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker #include <cmath>
9*da0073e9SAndroid Build Coastguard Worker #include <cstdlib>
10*da0073e9SAndroid Build Coastguard Worker #include <random>
11*da0073e9SAndroid Build Coastguard Worker 
12*da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
13*da0073e9SAndroid Build Coastguard Worker using namespace torch::test;
14*da0073e9SAndroid Build Coastguard Worker 
// Pi, used to convert the CartPole failure angle from degrees to radians.
constexpr double kPi = 3.1415926535898;
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker class CartPole {
18*da0073e9SAndroid Build Coastguard Worker   // Translated from openai/gym's cartpole.py
19*da0073e9SAndroid Build Coastguard Worker  public:
20*da0073e9SAndroid Build Coastguard Worker   double gravity = 9.8;
21*da0073e9SAndroid Build Coastguard Worker   double masscart = 1.0;
22*da0073e9SAndroid Build Coastguard Worker   double masspole = 0.1;
23*da0073e9SAndroid Build Coastguard Worker   double total_mass = (masspole + masscart);
24*da0073e9SAndroid Build Coastguard Worker   double length = 0.5; // actually half the pole's length;
25*da0073e9SAndroid Build Coastguard Worker   double polemass_length = (masspole * length);
26*da0073e9SAndroid Build Coastguard Worker   double force_mag = 10.0;
27*da0073e9SAndroid Build Coastguard Worker   double tau = 0.02; // seconds between state updates;
28*da0073e9SAndroid Build Coastguard Worker 
29*da0073e9SAndroid Build Coastguard Worker   // Angle at which to fail the episode
30*da0073e9SAndroid Build Coastguard Worker   double theta_threshold_radians = 12 * 2 * kPi / 360;
31*da0073e9SAndroid Build Coastguard Worker   double x_threshold = 2.4;
32*da0073e9SAndroid Build Coastguard Worker   int steps_beyond_done = -1;
33*da0073e9SAndroid Build Coastguard Worker 
34*da0073e9SAndroid Build Coastguard Worker   torch::Tensor state;
35*da0073e9SAndroid Build Coastguard Worker   double reward;
36*da0073e9SAndroid Build Coastguard Worker   bool done;
37*da0073e9SAndroid Build Coastguard Worker   int step_ = 0;
38*da0073e9SAndroid Build Coastguard Worker 
getState()39*da0073e9SAndroid Build Coastguard Worker   torch::Tensor getState() {
40*da0073e9SAndroid Build Coastguard Worker     return state;
41*da0073e9SAndroid Build Coastguard Worker   }
42*da0073e9SAndroid Build Coastguard Worker 
getReward()43*da0073e9SAndroid Build Coastguard Worker   double getReward() {
44*da0073e9SAndroid Build Coastguard Worker     return reward;
45*da0073e9SAndroid Build Coastguard Worker   }
46*da0073e9SAndroid Build Coastguard Worker 
isDone()47*da0073e9SAndroid Build Coastguard Worker   double isDone() {
48*da0073e9SAndroid Build Coastguard Worker     return done;
49*da0073e9SAndroid Build Coastguard Worker   }
50*da0073e9SAndroid Build Coastguard Worker 
reset()51*da0073e9SAndroid Build Coastguard Worker   void reset() {
52*da0073e9SAndroid Build Coastguard Worker     state = torch::empty({4}).uniform_(-0.05, 0.05);
53*da0073e9SAndroid Build Coastguard Worker     steps_beyond_done = -1;
54*da0073e9SAndroid Build Coastguard Worker     step_ = 0;
55*da0073e9SAndroid Build Coastguard Worker   }
56*da0073e9SAndroid Build Coastguard Worker 
57*da0073e9SAndroid Build Coastguard Worker   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
CartPole()58*da0073e9SAndroid Build Coastguard Worker   CartPole() {
59*da0073e9SAndroid Build Coastguard Worker     reset();
60*da0073e9SAndroid Build Coastguard Worker   }
61*da0073e9SAndroid Build Coastguard Worker 
step(int action)62*da0073e9SAndroid Build Coastguard Worker   void step(int action) {
63*da0073e9SAndroid Build Coastguard Worker     auto x = state[0].item<float>();
64*da0073e9SAndroid Build Coastguard Worker     auto x_dot = state[1].item<float>();
65*da0073e9SAndroid Build Coastguard Worker     auto theta = state[2].item<float>();
66*da0073e9SAndroid Build Coastguard Worker     auto theta_dot = state[3].item<float>();
67*da0073e9SAndroid Build Coastguard Worker 
68*da0073e9SAndroid Build Coastguard Worker     auto force = (action == 1) ? force_mag : -force_mag;
69*da0073e9SAndroid Build Coastguard Worker     auto costheta = std::cos(theta);
70*da0073e9SAndroid Build Coastguard Worker     auto sintheta = std::sin(theta);
71*da0073e9SAndroid Build Coastguard Worker     auto temp = (force + polemass_length * theta_dot * theta_dot * sintheta) /
72*da0073e9SAndroid Build Coastguard Worker         total_mass;
73*da0073e9SAndroid Build Coastguard Worker     auto thetaacc = (gravity * sintheta - costheta * temp) /
74*da0073e9SAndroid Build Coastguard Worker         (length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass));
75*da0073e9SAndroid Build Coastguard Worker     auto xacc = temp - polemass_length * thetaacc * costheta / total_mass;
76*da0073e9SAndroid Build Coastguard Worker 
77*da0073e9SAndroid Build Coastguard Worker     x = x + tau * x_dot;
78*da0073e9SAndroid Build Coastguard Worker     x_dot = x_dot + tau * xacc;
79*da0073e9SAndroid Build Coastguard Worker     theta = theta + tau * theta_dot;
80*da0073e9SAndroid Build Coastguard Worker     theta_dot = theta_dot + tau * thetaacc;
81*da0073e9SAndroid Build Coastguard Worker     state = torch::tensor({x, x_dot, theta, theta_dot});
82*da0073e9SAndroid Build Coastguard Worker 
83*da0073e9SAndroid Build Coastguard Worker     done = x < -x_threshold || x > x_threshold ||
84*da0073e9SAndroid Build Coastguard Worker         theta < -theta_threshold_radians || theta > theta_threshold_radians ||
85*da0073e9SAndroid Build Coastguard Worker         step_ > 200;
86*da0073e9SAndroid Build Coastguard Worker 
87*da0073e9SAndroid Build Coastguard Worker     if (!done) {
88*da0073e9SAndroid Build Coastguard Worker       reward = 1.0;
89*da0073e9SAndroid Build Coastguard Worker     } else if (steps_beyond_done == -1) {
90*da0073e9SAndroid Build Coastguard Worker       // Pole just fell!
91*da0073e9SAndroid Build Coastguard Worker       steps_beyond_done = 0;
92*da0073e9SAndroid Build Coastguard Worker       reward = 0;
93*da0073e9SAndroid Build Coastguard Worker     } else {
94*da0073e9SAndroid Build Coastguard Worker       if (steps_beyond_done == 0) {
95*da0073e9SAndroid Build Coastguard Worker         AT_ASSERT(false); // Can't do this
96*da0073e9SAndroid Build Coastguard Worker       }
97*da0073e9SAndroid Build Coastguard Worker     }
98*da0073e9SAndroid Build Coastguard Worker     step_++;
99*da0073e9SAndroid Build Coastguard Worker   }
100*da0073e9SAndroid Build Coastguard Worker };
101*da0073e9SAndroid Build Coastguard Worker 
102*da0073e9SAndroid Build Coastguard Worker template <typename M, typename F, typename O>
test_mnist(size_t batch_size,size_t number_of_epochs,bool with_cuda,M && model,F && forward_op,O && optimizer)103*da0073e9SAndroid Build Coastguard Worker bool test_mnist(
104*da0073e9SAndroid Build Coastguard Worker     size_t batch_size,
105*da0073e9SAndroid Build Coastguard Worker     size_t number_of_epochs,
106*da0073e9SAndroid Build Coastguard Worker     bool with_cuda,
107*da0073e9SAndroid Build Coastguard Worker     M&& model,
108*da0073e9SAndroid Build Coastguard Worker     F&& forward_op,
109*da0073e9SAndroid Build Coastguard Worker     O&& optimizer) {
110*da0073e9SAndroid Build Coastguard Worker   std::string mnist_path = "mnist";
111*da0073e9SAndroid Build Coastguard Worker   if (const char* user_mnist_path = getenv("TORCH_CPP_TEST_MNIST_PATH")) {
112*da0073e9SAndroid Build Coastguard Worker     mnist_path = user_mnist_path;
113*da0073e9SAndroid Build Coastguard Worker   }
114*da0073e9SAndroid Build Coastguard Worker 
115*da0073e9SAndroid Build Coastguard Worker   auto train_dataset =
116*da0073e9SAndroid Build Coastguard Worker       torch::data::datasets::MNIST(
117*da0073e9SAndroid Build Coastguard Worker           mnist_path, torch::data::datasets::MNIST::Mode::kTrain)
118*da0073e9SAndroid Build Coastguard Worker           .map(torch::data::transforms::Stack<>());
119*da0073e9SAndroid Build Coastguard Worker 
120*da0073e9SAndroid Build Coastguard Worker   auto data_loader =
121*da0073e9SAndroid Build Coastguard Worker       torch::data::make_data_loader(std::move(train_dataset), batch_size);
122*da0073e9SAndroid Build Coastguard Worker 
123*da0073e9SAndroid Build Coastguard Worker   torch::Device device(with_cuda ? torch::kCUDA : torch::kCPU);
124*da0073e9SAndroid Build Coastguard Worker   model->to(device);
125*da0073e9SAndroid Build Coastguard Worker 
126*da0073e9SAndroid Build Coastguard Worker   for (const auto epoch : c10::irange(number_of_epochs)) {
127*da0073e9SAndroid Build Coastguard Worker     (void)epoch; // Suppress unused variable warning
128*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(performance-for-range-copy)
129*da0073e9SAndroid Build Coastguard Worker     for (torch::data::Example<> batch : *data_loader) {
130*da0073e9SAndroid Build Coastguard Worker       auto data = batch.data.to(device);
131*da0073e9SAndroid Build Coastguard Worker       auto targets = batch.target.to(device);
132*da0073e9SAndroid Build Coastguard Worker       torch::Tensor prediction = forward_op(std::move(data));
133*da0073e9SAndroid Build Coastguard Worker       // NOLINTNEXTLINE(performance-move-const-arg)
134*da0073e9SAndroid Build Coastguard Worker       torch::Tensor loss = torch::nll_loss(prediction, std::move(targets));
135*da0073e9SAndroid Build Coastguard Worker       AT_ASSERT(!torch::isnan(loss).any().item<int64_t>());
136*da0073e9SAndroid Build Coastguard Worker       optimizer.zero_grad();
137*da0073e9SAndroid Build Coastguard Worker       loss.backward();
138*da0073e9SAndroid Build Coastguard Worker       optimizer.step();
139*da0073e9SAndroid Build Coastguard Worker     }
140*da0073e9SAndroid Build Coastguard Worker   }
141*da0073e9SAndroid Build Coastguard Worker 
142*da0073e9SAndroid Build Coastguard Worker   torch::NoGradGuard guard;
143*da0073e9SAndroid Build Coastguard Worker   torch::data::datasets::MNIST test_dataset(
144*da0073e9SAndroid Build Coastguard Worker       mnist_path, torch::data::datasets::MNIST::Mode::kTest);
145*da0073e9SAndroid Build Coastguard Worker   auto images = test_dataset.images().to(device),
146*da0073e9SAndroid Build Coastguard Worker        targets = test_dataset.targets().to(device);
147*da0073e9SAndroid Build Coastguard Worker 
148*da0073e9SAndroid Build Coastguard Worker   auto result = std::get<1>(forward_op(images).max(/*dim=*/1));
149*da0073e9SAndroid Build Coastguard Worker   torch::Tensor correct = (result == targets).to(torch::kFloat32);
150*da0073e9SAndroid Build Coastguard Worker   return correct.sum().item<float>() > (test_dataset.size().value() * 0.8);
151*da0073e9SAndroid Build Coastguard Worker }
152*da0073e9SAndroid Build Coastguard Worker 
153*da0073e9SAndroid Build Coastguard Worker struct IntegrationTest : torch::test::SeedingFixture {};
154*da0073e9SAndroid Build Coastguard Worker 
TEST_F(IntegrationTest,CartPole)155*da0073e9SAndroid Build Coastguard Worker TEST_F(IntegrationTest, CartPole) {
156*da0073e9SAndroid Build Coastguard Worker   torch::manual_seed(0);
157*da0073e9SAndroid Build Coastguard Worker   auto model = std::make_shared<SimpleContainer>();
158*da0073e9SAndroid Build Coastguard Worker   auto linear = model->add(Linear(4, 128), "linear");
159*da0073e9SAndroid Build Coastguard Worker   auto policyHead = model->add(Linear(128, 2), "policy");
160*da0073e9SAndroid Build Coastguard Worker   auto valueHead = model->add(Linear(128, 1), "action");
161*da0073e9SAndroid Build Coastguard Worker   auto optimizer = torch::optim::Adam(model->parameters(), 1e-3);
162*da0073e9SAndroid Build Coastguard Worker 
163*da0073e9SAndroid Build Coastguard Worker   std::vector<torch::Tensor> saved_log_probs;
164*da0073e9SAndroid Build Coastguard Worker   std::vector<torch::Tensor> saved_values;
165*da0073e9SAndroid Build Coastguard Worker   std::vector<float> rewards;
166*da0073e9SAndroid Build Coastguard Worker 
167*da0073e9SAndroid Build Coastguard Worker   auto forward = [&](torch::Tensor inp) {
168*da0073e9SAndroid Build Coastguard Worker     auto x = linear->forward(inp).clamp_min(0);
169*da0073e9SAndroid Build Coastguard Worker     torch::Tensor actions = policyHead->forward(x);
170*da0073e9SAndroid Build Coastguard Worker     torch::Tensor value = valueHead->forward(x);
171*da0073e9SAndroid Build Coastguard Worker     return std::make_tuple(torch::softmax(actions, -1), value);
172*da0073e9SAndroid Build Coastguard Worker   };
173*da0073e9SAndroid Build Coastguard Worker 
174*da0073e9SAndroid Build Coastguard Worker   auto selectAction = [&](torch::Tensor state) {
175*da0073e9SAndroid Build Coastguard Worker     // Only work on single state right now, change index to gather for batch
176*da0073e9SAndroid Build Coastguard Worker     auto out = forward(state);
177*da0073e9SAndroid Build Coastguard Worker     auto probs = torch::Tensor(std::get<0>(out));
178*da0073e9SAndroid Build Coastguard Worker     auto value = torch::Tensor(std::get<1>(out));
179*da0073e9SAndroid Build Coastguard Worker     auto action = probs.multinomial(1)[0].item<int32_t>();
180*da0073e9SAndroid Build Coastguard Worker     // Compute the log prob of a multinomial distribution.
181*da0073e9SAndroid Build Coastguard Worker     // This should probably be actually implemented in autogradpp...
182*da0073e9SAndroid Build Coastguard Worker     auto p = probs / probs.sum(-1, true);
183*da0073e9SAndroid Build Coastguard Worker     auto log_prob = p[action].log();
184*da0073e9SAndroid Build Coastguard Worker     saved_log_probs.emplace_back(log_prob);
185*da0073e9SAndroid Build Coastguard Worker     saved_values.push_back(value);
186*da0073e9SAndroid Build Coastguard Worker     return action;
187*da0073e9SAndroid Build Coastguard Worker   };
188*da0073e9SAndroid Build Coastguard Worker 
189*da0073e9SAndroid Build Coastguard Worker   auto finishEpisode = [&] {
190*da0073e9SAndroid Build Coastguard Worker     auto R = 0.;
191*da0073e9SAndroid Build Coastguard Worker     // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
192*da0073e9SAndroid Build Coastguard Worker     for (int i = rewards.size() - 1; i >= 0; i--) {
193*da0073e9SAndroid Build Coastguard Worker       R = rewards[i] + 0.99 * R;
194*da0073e9SAndroid Build Coastguard Worker       rewards[i] = R;
195*da0073e9SAndroid Build Coastguard Worker     }
196*da0073e9SAndroid Build Coastguard Worker     auto r_t = torch::from_blob(
197*da0073e9SAndroid Build Coastguard Worker         rewards.data(), {static_cast<int64_t>(rewards.size())});
198*da0073e9SAndroid Build Coastguard Worker     r_t = (r_t - r_t.mean()) / (r_t.std() + 1e-5);
199*da0073e9SAndroid Build Coastguard Worker 
200*da0073e9SAndroid Build Coastguard Worker     std::vector<torch::Tensor> policy_loss;
201*da0073e9SAndroid Build Coastguard Worker     std::vector<torch::Tensor> value_loss;
202*da0073e9SAndroid Build Coastguard Worker     for (const auto i : c10::irange(0U, saved_log_probs.size())) {
203*da0073e9SAndroid Build Coastguard Worker       auto advantage = r_t[i] - saved_values[i].item<float>();
204*da0073e9SAndroid Build Coastguard Worker       policy_loss.push_back(-advantage * saved_log_probs[i]);
205*da0073e9SAndroid Build Coastguard Worker       value_loss.push_back(
206*da0073e9SAndroid Build Coastguard Worker           torch::smooth_l1_loss(saved_values[i], torch::ones(1) * r_t[i]));
207*da0073e9SAndroid Build Coastguard Worker     }
208*da0073e9SAndroid Build Coastguard Worker 
209*da0073e9SAndroid Build Coastguard Worker     auto loss =
210*da0073e9SAndroid Build Coastguard Worker         torch::stack(policy_loss).sum() + torch::stack(value_loss).sum();
211*da0073e9SAndroid Build Coastguard Worker 
212*da0073e9SAndroid Build Coastguard Worker     optimizer.zero_grad();
213*da0073e9SAndroid Build Coastguard Worker     loss.backward();
214*da0073e9SAndroid Build Coastguard Worker     optimizer.step();
215*da0073e9SAndroid Build Coastguard Worker 
216*da0073e9SAndroid Build Coastguard Worker     rewards.clear();
217*da0073e9SAndroid Build Coastguard Worker     saved_log_probs.clear();
218*da0073e9SAndroid Build Coastguard Worker     saved_values.clear();
219*da0073e9SAndroid Build Coastguard Worker   };
220*da0073e9SAndroid Build Coastguard Worker 
221*da0073e9SAndroid Build Coastguard Worker   auto env = CartPole();
222*da0073e9SAndroid Build Coastguard Worker   double running_reward = 10.0;
223*da0073e9SAndroid Build Coastguard Worker   for (size_t episode = 0;; episode++) {
224*da0073e9SAndroid Build Coastguard Worker     env.reset();
225*da0073e9SAndroid Build Coastguard Worker     auto state = env.getState();
226*da0073e9SAndroid Build Coastguard Worker     int t = 0;
227*da0073e9SAndroid Build Coastguard Worker     for (; t < 10000; t++) {
228*da0073e9SAndroid Build Coastguard Worker       auto action = selectAction(state);
229*da0073e9SAndroid Build Coastguard Worker       env.step(action);
230*da0073e9SAndroid Build Coastguard Worker       state = env.getState();
231*da0073e9SAndroid Build Coastguard Worker       auto reward = env.getReward();
232*da0073e9SAndroid Build Coastguard Worker       auto done = env.isDone();
233*da0073e9SAndroid Build Coastguard Worker 
234*da0073e9SAndroid Build Coastguard Worker       rewards.push_back(reward);
235*da0073e9SAndroid Build Coastguard Worker       if (done)
236*da0073e9SAndroid Build Coastguard Worker         break;
237*da0073e9SAndroid Build Coastguard Worker     }
238*da0073e9SAndroid Build Coastguard Worker 
239*da0073e9SAndroid Build Coastguard Worker     running_reward = running_reward * 0.99 + t * 0.01;
240*da0073e9SAndroid Build Coastguard Worker     finishEpisode();
241*da0073e9SAndroid Build Coastguard Worker     /*
242*da0073e9SAndroid Build Coastguard Worker     if (episode % 10 == 0) {
243*da0073e9SAndroid Build Coastguard Worker       printf("Episode %i\tLast length: %5d\tAverage length: %.2f\n",
244*da0073e9SAndroid Build Coastguard Worker               episode, t, running_reward);
245*da0073e9SAndroid Build Coastguard Worker     }
246*da0073e9SAndroid Build Coastguard Worker     */
247*da0073e9SAndroid Build Coastguard Worker     if (running_reward > 150) {
248*da0073e9SAndroid Build Coastguard Worker       break;
249*da0073e9SAndroid Build Coastguard Worker     }
250*da0073e9SAndroid Build Coastguard Worker     ASSERT_LT(episode, 3000);
251*da0073e9SAndroid Build Coastguard Worker   }
252*da0073e9SAndroid Build Coastguard Worker }
253*da0073e9SAndroid Build Coastguard Worker 
// Trains a small conv net with dropout on MNIST (on CUDA) and expects more
// than 80% test accuracy from test_mnist().
TEST_F(IntegrationTest, MNIST_CUDA) {
  torch::manual_seed(0);
  auto model = std::make_shared<SimpleContainer>();
  auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
  auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
  // Register the dropout layers on `model` so that whole-model operations
  // such as model->train()/model->eval() reach them too. They hold no
  // parameters, so model->parameters() (and hence the optimizer) is unchanged.
  auto drop = model->add(Dropout(0.3), "drop");
  auto drop2d = model->add(Dropout2d(0.3), "drop2d");
  auto linear1 = model->add(Linear(320, 50), "linear1");
  auto linear2 = model->add(Linear(50, 10), "linear2");

  // Maps a batch of images to per-class log-probabilities.
  auto forward = [&](torch::Tensor x) {
    x = torch::max_pool2d(conv1->forward(x), {2, 2}).relu();
    x = conv2->forward(x);
    x = drop2d->forward(x);
    x = torch::max_pool2d(x, {2, 2}).relu();

    x = x.view({-1, 320});
    x = linear1->forward(x).clamp_min(0);
    x = drop->forward(x);
    x = linear2->forward(x);
    x = torch::log_softmax(x, 1);
    return x;
  };

  auto optimizer = torch::optim::SGD(
      model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));

  ASSERT_TRUE(test_mnist(
      32, // batch_size
      3, // number_of_epochs
      true, // with_cuda
      model,
      forward,
      optimizer));
}
289*da0073e9SAndroid Build Coastguard Worker 
// Same MNIST check as MNIST_CUDA, but the network uses batch normalization
// instead of dropout: 2-D batch norm after the first conv block and 1-D batch
// norm between the two fully connected layers.
TEST_F(IntegrationTest, MNISTBatchNorm_CUDA) {
  torch::manual_seed(0);

  auto net = std::make_shared<SimpleContainer>();
  auto conv1 = net->add(Conv2d(1, 10, 5), "conv1");
  auto batchnorm2d = net->add(BatchNorm2d(10), "batchnorm2d");
  auto conv2 = net->add(Conv2d(10, 20, 5), "conv2");
  auto linear1 = net->add(Linear(320, 50), "linear1");
  auto batchnorm1 = net->add(BatchNorm1d(50), "batchnorm1");
  auto linear2 = net->add(Linear(50, 10), "linear2");

  // Maps a batch of images to per-class log-probabilities.
  auto forward = [&](torch::Tensor images) {
    auto h = torch::max_pool2d(conv1->forward(images), {2, 2}).relu();
    h = conv2->forward(batchnorm2d->forward(h));
    h = torch::max_pool2d(h, {2, 2}).relu().view({-1, 320});
    h = batchnorm1->forward(linear1->forward(h).clamp_min(0));
    return torch::log_softmax(linear2->forward(h), 1);
  };

  auto sgd = torch::optim::SGD(
      net->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));

  const size_t kBatchSize = 32;
  const size_t kNumEpochs = 3;
  const bool kUseCuda = true;
  ASSERT_TRUE(test_mnist(kBatchSize, kNumEpochs, kUseCuda, net, forward, sgd));
}
325