xref: /aosp_15_r20/external/pytorch/test/cpp/api/init.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/util/irange.h>
4 #include <torch/torch.h>
5 
6 #include <test/cpp/api/init_baseline.h>
7 #include <test/cpp/api/support.h>
8 
9 #include <functional>
10 #include <vector>
11 
check_exact_values(const std::vector<torch::Tensor> & parameters,const std::vector<std::vector<torch::Tensor>> & expected_parameters)12 void check_exact_values(
13     const std::vector<torch::Tensor>& parameters,
14     const std::vector<std::vector<torch::Tensor>>& expected_parameters) {
15   ASSERT_EQ(parameters.size(), expected_parameters.size());
16 
17   for (const auto i : c10::irange(parameters.size())) {
18     auto layerParameters = parameters[i];
19     auto expectedLayerParameters = expected_parameters[i];
20 
21     if (static_cast<size_t>(layerParameters.size(0)) !=
22         expectedLayerParameters.size()) {
23       std::cout << "layer #" << i
24                 << " layerParameters size: " << layerParameters.size(0)
25                 << " != "
26                 << " expectedLayerParameters size: "
27                 << expectedLayerParameters.size() << std::endl;
28       ASSERT_TRUE(false);
29     }
30 
31     for (const auto p : c10::irange(layerParameters.size(0))) {
32       // Always compare using double dtype, regardless of the original dtype of
33       // the tensors
34       auto tensor = layerParameters[p].to(torch::kFloat64);
35       auto expectedTensor = expectedLayerParameters[p].to(torch::kFloat64);
36 
37       if (!tensor.allclose(expectedTensor, /*rtol=*/1e-3, /*atol=*/5e-4)) {
38         std::cout << "layer " << i << ": " << tensor << " != " << expectedTensor
39                   << " (parameter " << p << ")" << std::endl;
40         ASSERT_TRUE(false);
41       }
42     }
43   }
44 }
45 
check_initializer_against_baseline(std::function<void (torch::Tensor)> initializer,std::vector<std::vector<torch::Tensor>> expected)46 void check_initializer_against_baseline(
47     std::function<void(torch::Tensor)> initializer,
48     std::vector<std::vector<torch::Tensor>> expected) {
49   torch::manual_seed(0);
50 
51   auto layer1 = torch::nn::Linear(7, 15);
52   initializer(layer1->weight);
53   layer1->to(torch::kFloat64);
54 
55   auto layer2 = torch::nn::Linear(15, 15);
56   initializer(layer2->weight);
57   layer2->to(torch::kFloat64);
58 
59   auto layer3 = torch::nn::Linear(15, 2);
60   initializer(layer3->weight);
61   layer3->to(torch::kFloat64);
62 
63   auto parameters = std::vector<torch::Tensor>{
64       layer1->weight,
65       layer2->weight,
66       layer3->weight,
67   };
68 
69   check_exact_values(parameters, expected);
70 }
71 
TEST(InitTest, ProducesPyTorchValues_XavierUniform) {
  // xavier_uniform_ must reproduce the reference weights stored in
  // expected_parameters::Xavier_Uniform().
  const auto baseline = expected_parameters::Xavier_Uniform();
  check_initializer_against_baseline(
      [](torch::Tensor weight) { torch::nn::init::xavier_uniform_(weight); },
      baseline);
}
79 
TEST(InitTest, ProducesPyTorchValues_XavierNormal) {
  // xavier_normal_ must reproduce the reference weights stored in
  // expected_parameters::Xavier_Normal().
  const auto baseline = expected_parameters::Xavier_Normal();
  check_initializer_against_baseline(
      [](torch::Tensor weight) { torch::nn::init::xavier_normal_(weight); },
      baseline);
}
87 
TEST(InitTest, ProducesPyTorchValues_KaimingNormal) {
  // kaiming_normal_ (default arguments) must reproduce the reference weights
  // stored in expected_parameters::Kaiming_Normal().
  const auto baseline = expected_parameters::Kaiming_Normal();
  check_initializer_against_baseline(
      [](torch::Tensor weight) { torch::nn::init::kaiming_normal_(weight); },
      baseline);
}
95 
TEST(InitTest, ProducesPyTorchValues_KaimingUniform) {
  // kaiming_uniform_ (default arguments) must reproduce the reference weights
  // stored in expected_parameters::Kaiming_Uniform().
  const auto baseline = expected_parameters::Kaiming_Uniform();
  check_initializer_against_baseline(
      [](torch::Tensor weight) { torch::nn::init::kaiming_uniform_(weight); },
      baseline);
}
103 
TEST(InitTest, CanInitializeTensorThatRequiresGrad) {
  auto leaf = torch::empty({3, 4}, torch::requires_grad());

  // A plain in-place op on a grad-requiring leaf is rejected...
  ASSERT_THROWS_WITH(
      leaf.fill_(1),
      "a leaf Variable that requires grad "
      "is being used in an in-place operation");

  // ...whereas the init helper succeeds on the very same tensor, filling all
  // 3 * 4 entries with one.
  const auto total = torch::nn::init::ones_(leaf).sum().item<int32_t>();
  ASSERT_EQ(total, 12);
}
112 
TEST(InitTest, CalculateGainWithTanh) {
  // calculate_gain should report exactly 5/3 for tanh.
  ASSERT_DOUBLE_EQ(torch::nn::init::calculate_gain(torch::kTanh), 5.0 / 3.0);
}
117 
TEST(InitTest, CalculateGainWithRelu) {
  // calculate_gain should report sqrt(2) for ReLU.
  const double expected_gain = std::sqrt(2.0);
  ASSERT_DOUBLE_EQ(
      torch::nn::init::calculate_gain(torch::kReLU), expected_gain);
}
122 
TEST(InitTest, CalculateGainWithLeakyRelu) {
  // Expected gain is sqrt(2 / (1 + slope^2)).
  // NOTE(review): 0.01 is assumed to match calculate_gain's default LeakyReLU
  // parameter (no param is passed below) — confirm if that default changes.
  constexpr double negative_slope = 0.01;
  double gain = torch::nn::init::calculate_gain(torch::kLeakyReLU);
  // Use std::pow: unqualified `pow` relies on the C library's global-namespace
  // declaration, which <cmath> is not required to provide.
  ASSERT_DOUBLE_EQ(
      gain, std::sqrt(2.0 / (1 + std::pow(negative_slope, 2))));
}
127 
TEST(InitTest, CanInitializeCnnWithOrthogonal) {
  // Conv2d(3 -> 2, kernel 3) has a 4-D weight; orthogonal_ must accept it.
  torch::nn::Conv2d conv_layer(torch::nn::Conv2dOptions(3, 2, 3).stride(2));
  // Copy the tensor handle out of the temporary OrderedDict; it still refers to
  // the layer's parameter.
  torch::Tensor weight = conv_layer->named_parameters()["weight"];
  torch::nn::init::orthogonal_(weight);

  // Previously this test asserted nothing. Check the actual contract: viewed
  // as a (out_channels, rest) matrix with fewer rows than columns, the weight
  // should have orthonormal rows (default gain 1), i.e. W * W^T == I.
  const auto flat = weight.detach().reshape({weight.size(0), -1});
  const auto gram = flat.mm(flat.t());
  const auto identity = torch::eye(flat.size(0), flat.options());
  ASSERT_TRUE(torch::allclose(gram, identity, /*rtol=*/1e-4, /*atol=*/1e-4))
      << "W * W^T =\n" << gram;
}
132