#include <gtest/gtest.h>

#include <c10/util/irange.h>
#include <torch/torch.h>

#include <test/cpp/api/support.h>

#include <cmath>
#include <cstdlib>
#include <random>

using namespace torch::nn;
using namespace torch::test;

const double kPi = 3.1415926535898;

class CartPole {
  // Translated from openai/gym's cartpole.py
 public:
  double gravity = 9.8;
  double masscart = 1.0;
  double masspole = 0.1;
  double total_mass = (masspole + masscart);
  double length = 0.5; // actually half the pole's length
  double polemass_length = (masspole * length);
  double force_mag = 10.0;
  double tau = 0.02; // seconds between state updates

  // Angle at which to fail the episode
  double theta_threshold_radians = 12 * 2 * kPi / 360;
  double x_threshold = 2.4;
  int steps_beyond_done = -1;

  torch::Tensor state;
  double reward;
  bool done;
  int step_ = 0;

  torch::Tensor getState() {
    return state;
  }

  double getReward() {
    return reward;
  }

  bool isDone() {
    return done;
  }

  void reset() {
    state = torch::empty({4}).uniform_(-0.05, 0.05);
    steps_beyond_done = -1;
    step_ = 0;
  }

  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  CartPole() {
    reset();
  }

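  // Advances the simulation by one time step: applies +/- force_mag depending
  // on the action, integrates the cart-pole equations of motion with explicit
  // Euler over tau seconds, and marks the episode done once the cart or the
  // pole leaves its threshold or more than 200 steps have elapsed.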
  void step(int action) {
    auto x = state[0].item<float>();
    auto x_dot = state[1].item<float>();
    auto theta = state[2].item<float>();
    auto theta_dot = state[3].item<float>();

    auto force = (action == 1) ? force_mag : -force_mag;
    auto costheta = std::cos(theta);
    auto sintheta = std::sin(theta);
    auto temp = (force + polemass_length * theta_dot * theta_dot * sintheta) /
        total_mass;
    auto thetaacc = (gravity * sintheta - costheta * temp) /
        (length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass));
    auto xacc = temp - polemass_length * thetaacc * costheta / total_mass;

    x = x + tau * x_dot;
    x_dot = x_dot + tau * xacc;
    theta = theta + tau * theta_dot;
    theta_dot = theta_dot + tau * thetaacc;
    state = torch::tensor({x, x_dot, theta, theta_dot});

    done = x < -x_threshold || x > x_threshold ||
        theta < -theta_threshold_radians || theta > theta_threshold_radians ||
        step_ > 200;

    if (!done) {
      reward = 1.0;
    } else if (steps_beyond_done == -1) {
      // Pole just fell!
      steps_beyond_done = 0;
      reward = 0;
    } else {
      if (steps_beyond_done == 0) {
        AT_ASSERT(false); // Can't do this
      }
    }
    step_++;
  }
};

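// Trains `model` on the MNIST training set for `number_of_epochs` epochs with
// the given forward function and optimizer, then returns true if it classifies
// more than 80% of the test set correctly. The dataset path defaults to
// "mnist" and can be overridden via TORCH_CPP_TEST_MNIST_PATH.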
template <typename M, typename F, typename O>
bool test_mnist(
    size_t batch_size,
    size_t number_of_epochs,
    bool with_cuda,
    M&& model,
    F&& forward_op,
    O&& optimizer) {
  std::string mnist_path = "mnist";
  if (const char* user_mnist_path = getenv("TORCH_CPP_TEST_MNIST_PATH")) {
    mnist_path = user_mnist_path;
  }

  auto train_dataset =
      torch::data::datasets::MNIST(
          mnist_path, torch::data::datasets::MNIST::Mode::kTrain)
          .map(torch::data::transforms::Stack<>());

  auto data_loader =
      torch::data::make_data_loader(std::move(train_dataset), batch_size);

  torch::Device device(with_cuda ? torch::kCUDA : torch::kCPU);
  model->to(device);

  for (const auto epoch : c10::irange(number_of_epochs)) {
    (void)epoch; // Suppress unused variable warning
    // NOLINTNEXTLINE(performance-for-range-copy)
    for (torch::data::Example<> batch : *data_loader) {
      auto data = batch.data.to(device);
      auto targets = batch.target.to(device);
      torch::Tensor prediction = forward_op(std::move(data));
      // NOLINTNEXTLINE(performance-move-const-arg)
      torch::Tensor loss = torch::nll_loss(prediction, std::move(targets));
      AT_ASSERT(!torch::isnan(loss).any().item<int64_t>());
      optimizer.zero_grad();
      loss.backward();
      optimizer.step();
    }
  }

  torch::NoGradGuard guard;
  torch::data::datasets::MNIST test_dataset(
      mnist_path, torch::data::datasets::MNIST::Mode::kTest);
  auto images = test_dataset.images().to(device),
       targets = test_dataset.targets().to(device);

  auto result = std::get<1>(forward_op(images).max(/*dim=*/1));
  torch::Tensor correct = (result == targets).to(torch::kFloat32);
  return correct.sum().item<float>() > (test_dataset.size().value() * 0.8);
}

struct IntegrationTest : torch::test::SeedingFixture {};

TEST_F(IntegrationTest, CartPole) {
  torch::manual_seed(0);
  auto model = std::make_shared<SimpleContainer>();
  auto linear = model->add(Linear(4, 128), "linear");
  auto policyHead = model->add(Linear(128, 2), "policy");
  auto valueHead = model->add(Linear(128, 1), "action");
  auto optimizer = torch::optim::Adam(model->parameters(), 1e-3);

  std::vector<torch::Tensor> saved_log_probs;
  std::vector<torch::Tensor> saved_values;
  std::vector<float> rewards;

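  // Actor-critic network: a shared linear trunk with ReLU feeds a policy head
  // (softmax over the two cart-pole actions) and a scalar value head.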
  auto forward = [&](torch::Tensor inp) {
    auto x = linear->forward(inp).clamp_min(0);
    torch::Tensor actions = policyHead->forward(x);
    torch::Tensor value = valueHead->forward(x);
    return std::make_tuple(torch::softmax(actions, -1), value);
  };

  auto selectAction = [&](torch::Tensor state) {
    // Only works on a single state right now; change index to gather for batch
    auto out = forward(state);
    auto probs = torch::Tensor(std::get<0>(out));
    auto value = torch::Tensor(std::get<1>(out));
    auto action = probs.multinomial(1)[0].item<int32_t>();
    // Compute the log prob of a multinomial distribution.
    // This should probably be actually implemented in autogradpp...
    auto p = probs / probs.sum(-1, true);
    auto log_prob = p[action].log();
    saved_log_probs.emplace_back(log_prob);
    saved_values.push_back(value);
    return action;
  };

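  // At the end of an episode, turn the per-step rewards into normalized
  // discounted returns (gamma = 0.99), then take one optimizer step on the
  // summed policy loss (-advantage * log_prob) plus a smooth L1 value loss.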
  auto finishEpisode = [&] {
    auto R = 0.;
    // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
    for (int i = rewards.size() - 1; i >= 0; i--) {
      R = rewards[i] + 0.99 * R;
      rewards[i] = R;
    }
    auto r_t = torch::from_blob(
        rewards.data(), {static_cast<int64_t>(rewards.size())});
    r_t = (r_t - r_t.mean()) / (r_t.std() + 1e-5);

    std::vector<torch::Tensor> policy_loss;
    std::vector<torch::Tensor> value_loss;
    for (const auto i : c10::irange(0U, saved_log_probs.size())) {
      auto advantage = r_t[i] - saved_values[i].item<float>();
      policy_loss.push_back(-advantage * saved_log_probs[i]);
      value_loss.push_back(
          torch::smooth_l1_loss(saved_values[i], torch::ones(1) * r_t[i]));
    }

    auto loss =
        torch::stack(policy_loss).sum() + torch::stack(value_loss).sum();

    optimizer.zero_grad();
    loss.backward();
    optimizer.step();

    rewards.clear();
    saved_log_probs.clear();
    saved_values.clear();
  };

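  // Run episodes until the exponentially smoothed episode length exceeds 150;
  // the test fails if that has not happened within 3000 episodes.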
  auto env = CartPole();
  double running_reward = 10.0;
  for (size_t episode = 0;; episode++) {
    env.reset();
    auto state = env.getState();
    int t = 0;
    for (; t < 10000; t++) {
      auto action = selectAction(state);
      env.step(action);
      state = env.getState();
      auto reward = env.getReward();
      auto done = env.isDone();

      rewards.push_back(reward);
      if (done)
        break;
    }

    running_reward = running_reward * 0.99 + t * 0.01;
    finishEpisode();
    /*
    if (episode % 10 == 0) {
      printf("Episode %i\tLast length: %5d\tAverage length: %.2f\n",
              episode, t, running_reward);
    }
    */
    if (running_reward > 150) {
      break;
    }
    ASSERT_LT(episode, 3000);
  }
}

TEST_F(IntegrationTest, MNIST_CUDA) {
  torch::manual_seed(0);
  auto model = std::make_shared<SimpleContainer>();
  auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
  auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
  auto drop = Dropout(0.3);
  auto drop2d = Dropout2d(0.3);
  auto linear1 = model->add(Linear(320, 50), "linear1");
  auto linear2 = model->add(Linear(50, 10), "linear2");

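  // LeNet-style network: two 5x5 convolutions with 2x2 max pooling and ReLU,
  // dropout for regularization, then two fully connected layers producing
  // log-probabilities over the ten digit classes.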
  auto forward = [&](torch::Tensor x) {
    x = torch::max_pool2d(conv1->forward(x), {2, 2}).relu();
    x = conv2->forward(x);
    x = drop2d->forward(x);
    x = torch::max_pool2d(x, {2, 2}).relu();

    x = x.view({-1, 320});
    x = linear1->forward(x).clamp_min(0);
    x = drop->forward(x);
    x = linear2->forward(x);
    x = torch::log_softmax(x, 1);
    return x;
  };

  auto optimizer = torch::optim::SGD(
      model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));

  ASSERT_TRUE(test_mnist(
      32, // batch_size
      3, // number_of_epochs
      true, // with_cuda
      model,
      forward,
      optimizer));
}

TEST_F(IntegrationTest, MNISTBatchNorm_CUDA) {
  torch::manual_seed(0);
  auto model = std::make_shared<SimpleContainer>();
  auto conv1 = model->add(Conv2d(1, 10, 5), "conv1");
  auto batchnorm2d = model->add(BatchNorm2d(10), "batchnorm2d");
  auto conv2 = model->add(Conv2d(10, 20, 5), "conv2");
  auto linear1 = model->add(Linear(320, 50), "linear1");
  auto batchnorm1 = model->add(BatchNorm1d(50), "batchnorm1");
  auto linear2 = model->add(Linear(50, 10), "linear2");

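  // Same convolutional stack as above, but with batch normalization after the
  // first convolution and the first fully connected layer instead of dropout.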
  auto forward = [&](torch::Tensor x) {
    x = torch::max_pool2d(conv1->forward(x), {2, 2}).relu();
    x = batchnorm2d->forward(x);
    x = conv2->forward(x);
    x = torch::max_pool2d(x, {2, 2}).relu();

    x = x.view({-1, 320});
    x = linear1->forward(x).clamp_min(0);
    x = batchnorm1->forward(x);
    x = linear2->forward(x);
    x = torch::log_softmax(x, 1);
    return x;
  };

  auto optimizer = torch::optim::SGD(
      model->parameters(), torch::optim::SGDOptions(1e-2).momentum(0.5));

  ASSERT_TRUE(test_mnist(
      32, // batch_size
      3, // number_of_epochs
      true, // with_cuda
      model,
      forward,
      optimizer));
}