xref: /aosp_15_r20/external/pytorch/test/cpp/api/init.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/init_baseline.h>
7*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker #include <functional>
10*da0073e9SAndroid Build Coastguard Worker #include <vector>
11*da0073e9SAndroid Build Coastguard Worker 
check_exact_values(const std::vector<torch::Tensor> & parameters,const std::vector<std::vector<torch::Tensor>> & expected_parameters)12*da0073e9SAndroid Build Coastguard Worker void check_exact_values(
13*da0073e9SAndroid Build Coastguard Worker     const std::vector<torch::Tensor>& parameters,
14*da0073e9SAndroid Build Coastguard Worker     const std::vector<std::vector<torch::Tensor>>& expected_parameters) {
15*da0073e9SAndroid Build Coastguard Worker   ASSERT_EQ(parameters.size(), expected_parameters.size());
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(parameters.size())) {
18*da0073e9SAndroid Build Coastguard Worker     auto layerParameters = parameters[i];
19*da0073e9SAndroid Build Coastguard Worker     auto expectedLayerParameters = expected_parameters[i];
20*da0073e9SAndroid Build Coastguard Worker 
21*da0073e9SAndroid Build Coastguard Worker     if (static_cast<size_t>(layerParameters.size(0)) !=
22*da0073e9SAndroid Build Coastguard Worker         expectedLayerParameters.size()) {
23*da0073e9SAndroid Build Coastguard Worker       std::cout << "layer #" << i
24*da0073e9SAndroid Build Coastguard Worker                 << " layerParameters size: " << layerParameters.size(0)
25*da0073e9SAndroid Build Coastguard Worker                 << " != "
26*da0073e9SAndroid Build Coastguard Worker                 << " expectedLayerParameters size: "
27*da0073e9SAndroid Build Coastguard Worker                 << expectedLayerParameters.size() << std::endl;
28*da0073e9SAndroid Build Coastguard Worker       ASSERT_TRUE(false);
29*da0073e9SAndroid Build Coastguard Worker     }
30*da0073e9SAndroid Build Coastguard Worker 
31*da0073e9SAndroid Build Coastguard Worker     for (const auto p : c10::irange(layerParameters.size(0))) {
32*da0073e9SAndroid Build Coastguard Worker       // Always compare using double dtype, regardless of the original dtype of
33*da0073e9SAndroid Build Coastguard Worker       // the tensors
34*da0073e9SAndroid Build Coastguard Worker       auto tensor = layerParameters[p].to(torch::kFloat64);
35*da0073e9SAndroid Build Coastguard Worker       auto expectedTensor = expectedLayerParameters[p].to(torch::kFloat64);
36*da0073e9SAndroid Build Coastguard Worker 
37*da0073e9SAndroid Build Coastguard Worker       if (!tensor.allclose(expectedTensor, /*rtol=*/1e-3, /*atol=*/5e-4)) {
38*da0073e9SAndroid Build Coastguard Worker         std::cout << "layer " << i << ": " << tensor << " != " << expectedTensor
39*da0073e9SAndroid Build Coastguard Worker                   << " (parameter " << p << ")" << std::endl;
40*da0073e9SAndroid Build Coastguard Worker         ASSERT_TRUE(false);
41*da0073e9SAndroid Build Coastguard Worker       }
42*da0073e9SAndroid Build Coastguard Worker     }
43*da0073e9SAndroid Build Coastguard Worker   }
44*da0073e9SAndroid Build Coastguard Worker }
45*da0073e9SAndroid Build Coastguard Worker 
check_initializer_against_baseline(std::function<void (torch::Tensor)> initializer,std::vector<std::vector<torch::Tensor>> expected)46*da0073e9SAndroid Build Coastguard Worker void check_initializer_against_baseline(
47*da0073e9SAndroid Build Coastguard Worker     std::function<void(torch::Tensor)> initializer,
48*da0073e9SAndroid Build Coastguard Worker     std::vector<std::vector<torch::Tensor>> expected) {
49*da0073e9SAndroid Build Coastguard Worker   torch::manual_seed(0);
50*da0073e9SAndroid Build Coastguard Worker 
51*da0073e9SAndroid Build Coastguard Worker   auto layer1 = torch::nn::Linear(7, 15);
52*da0073e9SAndroid Build Coastguard Worker   initializer(layer1->weight);
53*da0073e9SAndroid Build Coastguard Worker   layer1->to(torch::kFloat64);
54*da0073e9SAndroid Build Coastguard Worker 
55*da0073e9SAndroid Build Coastguard Worker   auto layer2 = torch::nn::Linear(15, 15);
56*da0073e9SAndroid Build Coastguard Worker   initializer(layer2->weight);
57*da0073e9SAndroid Build Coastguard Worker   layer2->to(torch::kFloat64);
58*da0073e9SAndroid Build Coastguard Worker 
59*da0073e9SAndroid Build Coastguard Worker   auto layer3 = torch::nn::Linear(15, 2);
60*da0073e9SAndroid Build Coastguard Worker   initializer(layer3->weight);
61*da0073e9SAndroid Build Coastguard Worker   layer3->to(torch::kFloat64);
62*da0073e9SAndroid Build Coastguard Worker 
63*da0073e9SAndroid Build Coastguard Worker   auto parameters = std::vector<torch::Tensor>{
64*da0073e9SAndroid Build Coastguard Worker       layer1->weight,
65*da0073e9SAndroid Build Coastguard Worker       layer2->weight,
66*da0073e9SAndroid Build Coastguard Worker       layer3->weight,
67*da0073e9SAndroid Build Coastguard Worker   };
68*da0073e9SAndroid Build Coastguard Worker 
69*da0073e9SAndroid Build Coastguard Worker   check_exact_values(parameters, expected);
70*da0073e9SAndroid Build Coastguard Worker }
71*da0073e9SAndroid Build Coastguard Worker 
// xavier_uniform_ must reproduce the baseline stored in
// expected_parameters::Xavier_Uniform() for the seeded layers.
TEST(InitTest, ProducesPyTorchValues_XavierUniform) {
  const auto baseline = expected_parameters::Xavier_Uniform();
  check_initializer_against_baseline(
      [](torch::Tensor weight) { torch::nn::init::xavier_uniform_(weight); },
      baseline);
}
79*da0073e9SAndroid Build Coastguard Worker 
// xavier_normal_ must reproduce the baseline stored in
// expected_parameters::Xavier_Normal() for the seeded layers.
TEST(InitTest, ProducesPyTorchValues_XavierNormal) {
  const auto baseline = expected_parameters::Xavier_Normal();
  check_initializer_against_baseline(
      [](torch::Tensor weight) { torch::nn::init::xavier_normal_(weight); },
      baseline);
}
87*da0073e9SAndroid Build Coastguard Worker 
// kaiming_normal_ (default arguments) must reproduce the baseline stored in
// expected_parameters::Kaiming_Normal() for the seeded layers.
TEST(InitTest, ProducesPyTorchValues_KaimingNormal) {
  const auto baseline = expected_parameters::Kaiming_Normal();
  check_initializer_against_baseline(
      [](torch::Tensor weight) { torch::nn::init::kaiming_normal_(weight); },
      baseline);
}
95*da0073e9SAndroid Build Coastguard Worker 
// kaiming_uniform_ (default arguments) must reproduce the baseline stored in
// expected_parameters::Kaiming_Uniform() for the seeded layers.
TEST(InitTest, ProducesPyTorchValues_KaimingUniform) {
  const auto baseline = expected_parameters::Kaiming_Uniform();
  check_initializer_against_baseline(
      [](torch::Tensor weight) { torch::nn::init::kaiming_uniform_(weight); },
      baseline);
}
103*da0073e9SAndroid Build Coastguard Worker 
// A plain in-place fill_ on a leaf that requires grad is rejected, but
// torch::nn::init::ones_ can still overwrite that same tensor.
TEST(InitTest, CanInitializeTensorThatRequiresGrad) {
  auto leaf = torch::empty({3, 4}, torch::requires_grad());
  ASSERT_THROWS_WITH(
      leaf.fill_(1),
      "a leaf Variable that requires grad "
      "is being used in an in-place operation");
  // 3 x 4 elements, all set to one.
  const auto total = torch::nn::init::ones_(leaf).sum().item<int32_t>();
  ASSERT_EQ(total, 12);
}
112*da0073e9SAndroid Build Coastguard Worker 
// The gain reported for tanh is expected to be exactly 5/3.
TEST(InitTest, CalculateGainWithTanh) {
  const double expected_gain = 5.0 / 3.0;
  ASSERT_DOUBLE_EQ(
      torch::nn::init::calculate_gain(torch::kTanh), expected_gain);
}
117*da0073e9SAndroid Build Coastguard Worker 
// The gain reported for ReLU is expected to be sqrt(2).
TEST(InitTest, CalculateGainWithRelu) {
  const double expected_gain = std::sqrt(2.0);
  ASSERT_DOUBLE_EQ(
      torch::nn::init::calculate_gain(torch::kReLU), expected_gain);
}
122*da0073e9SAndroid Build Coastguard Worker 
// With no explicit param, the LeakyReLU gain is expected to equal
// sqrt(2 / (1 + slope^2)) for a negative slope of 0.01.
TEST(InitTest, CalculateGainWithLeakyRelu) {
  // Negative slope this expectation assumes calculate_gain uses when no param
  // is passed.
  const double negative_slope = 0.01;
  double gain = torch::nn::init::calculate_gain(torch::kLeakyReLU);
  // Square by multiplication rather than unqualified pow(): <cmath> only
  // guarantees std::pow, so relying on a global ::pow is not portable.
  ASSERT_DOUBLE_EQ(
      gain, std::sqrt(2.0 / (1 + negative_slope * negative_slope)));
}
127*da0073e9SAndroid Build Coastguard Worker 
// orthogonal_ must accept a conv weight (more than 2 dims) and produce a
// (semi-)orthogonal matrix once all dimensions after the first are flattened.
// With fewer rows than columns (expected here), the rows of that matrix
// should be orthonormal, so W * W^T should be the identity (default gain 1).
TEST(InitTest, CanInitializeCnnWithOrthogonal) {
  torch::nn::Conv2d conv_layer(torch::nn::Conv2dOptions(3, 2, 3).stride(2));
  torch::nn::init::orthogonal_(conv_layer->named_parameters()["weight"]);

  // Read back through the module to confirm the in-place initialization
  // landed in the registered parameter; detach so no autograd graph is built.
  const auto flat = conv_layer->weight.detach().reshape(
      {conv_layer->weight.size(0), -1});
  const auto gram = flat.mm(flat.t());
  ASSERT_TRUE(gram.allclose(
      torch::eye(flat.size(0), gram.options()), /*rtol=*/1e-5, /*atol=*/1e-4))
      << "rows of the flattened weight are not orthonormal:\n"
      << gram;
}
132