// xref: /aosp_15_r20/external/pytorch/test/cpp/api/misc.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker #include <functional>
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker using namespace torch::test;
10*da0073e9SAndroid Build Coastguard Worker 
torch_warn_once_A()11*da0073e9SAndroid Build Coastguard Worker void torch_warn_once_A() {
12*da0073e9SAndroid Build Coastguard Worker   TORCH_WARN_ONCE("warn once");
13*da0073e9SAndroid Build Coastguard Worker }
14*da0073e9SAndroid Build Coastguard Worker 
torch_warn_once_B()15*da0073e9SAndroid Build Coastguard Worker void torch_warn_once_B() {
16*da0073e9SAndroid Build Coastguard Worker   TORCH_WARN_ONCE("warn something else once");
17*da0073e9SAndroid Build Coastguard Worker }
18*da0073e9SAndroid Build Coastguard Worker 
torch_warn()19*da0073e9SAndroid Build Coastguard Worker void torch_warn() {
20*da0073e9SAndroid Build Coastguard Worker   TORCH_WARN("warn multiple times");
21*da0073e9SAndroid Build Coastguard Worker }
22*da0073e9SAndroid Build Coastguard Worker 
// Checks the difference between TORCH_WARN_ONCE and TORCH_WARN. Repeated calls
// to a warn-once site are captured a single time. Repeated TORCH_WARN calls are
// captured every time.
TEST(UtilsTest, WarnOnce) {
  {
    WarningCapture captured;

    // Call order is A, A, B, B.
    for (int repeat = 0; repeat < 2; ++repeat) {
      torch_warn_once_A();
    }
    for (int repeat = 0; repeat < 2; ++repeat) {
      torch_warn_once_B();
    }

    const auto text = captured.str();
    ASSERT_EQ(count_substr_occurrences(text, "warn once"), 1);
    ASSERT_EQ(count_substr_occurrences(text, "warn something else once"), 1);
  }
  {
    WarningCapture captured;

    for (int repeat = 0; repeat < 3; ++repeat) {
      torch_warn();
    }

    ASSERT_EQ(
        count_substr_occurrences(captured.str(), "warn multiple times"), 3);
  }
}
48*da0073e9SAndroid Build Coastguard Worker 
// Under a NoGradGuard, the result of a forward pass has no grad_fn, so calling
// backward() on it must fail.
TEST(NoGradTest, SetsGradModeCorrectly) {
  torch::manual_seed(0);
  torch::NoGradGuard no_grad;

  torch::nn::Linear layer(5, 2);
  auto input = torch::randn({10, 5}, torch::requires_grad());
  torch::Tensor total = layer->forward(input).sum();

  // Mimicking python API behavior:
  ASSERT_THROWS_WITH(
      total.backward(),
      "element 0 of tensors does not require grad and does not have a grad_fn");
}
62*da0073e9SAndroid Build Coastguard Worker 
63*da0073e9SAndroid Build Coastguard Worker struct AutogradTest : torch::test::SeedingFixture {
AutogradTestAutogradTest64*da0073e9SAndroid Build Coastguard Worker   AutogradTest() {
65*da0073e9SAndroid Build Coastguard Worker     x = torch::randn({3, 3}, torch::requires_grad());
66*da0073e9SAndroid Build Coastguard Worker     y = torch::randn({3, 3});
67*da0073e9SAndroid Build Coastguard Worker     z = x * y;
68*da0073e9SAndroid Build Coastguard Worker   }
69*da0073e9SAndroid Build Coastguard Worker   torch::Tensor x, y, z;
70*da0073e9SAndroid Build Coastguard Worker };
71*da0073e9SAndroid Build Coastguard Worker 
// Backpropagates through the non-scalar z with an explicit all-ones gradient.
TEST_F(AutogradTest, CanTakeDerivatives) {
  const auto upstream = torch::ones_like(z);
  z.backward(upstream);
  // d(x * y)/dx weighted by ones is y.
  ASSERT_TRUE(x.grad().allclose(y));
}
76*da0073e9SAndroid Build Coastguard Worker 
// Calls backward() with no explicit gradient on the 0-dim result of sum().
TEST_F(AutogradTest, CanTakeDerivativesOfZeroDimTensors) {
  auto total = z.sum();
  total.backward();
  ASSERT_TRUE(x.grad().allclose(y));
}
81*da0073e9SAndroid Build Coastguard Worker 
// Passes a custom 0-dim upstream gradient of 2. The resulting x.grad() is
// expected to be scaled by the same factor.
TEST_F(AutogradTest, CanPassCustomGradientInputs) {
  const auto scale = torch::ones({}) * 2;
  z.sum().backward(scale);
  ASSERT_TRUE(x.grad().allclose(y * 2));
}
86*da0073e9SAndroid Build Coastguard Worker 
// Mainly a compile-time check: each call below must resolve to a single
// overload of at::_test_ambiguous_defaults. The overloads presumably have
// overlapping default arguments, which is what makes this worth testing.
// Keep the argument literals exactly as written, since their types drive
// overload resolution.
TEST(UtilsTest, AmbiguousOperatorDefaults) {
  auto dummy = at::empty({}, at::kCPU);
  at::_test_ambiguous_defaults(dummy);
  at::_test_ambiguous_defaults(dummy, 1);
  at::_test_ambiguous_defaults(dummy, 1, 1);
  at::_test_ambiguous_defaults(dummy, 2, "2");
}
94*da0073e9SAndroid Build Coastguard Worker 
get_first_element(c10::OptionalIntArrayRef arr)95*da0073e9SAndroid Build Coastguard Worker int64_t get_first_element(c10::OptionalIntArrayRef arr) {
96*da0073e9SAndroid Build Coastguard Worker   return arr.value()[0];
97*da0073e9SAndroid Build Coastguard Worker }
98*da0073e9SAndroid Build Coastguard Worker 
TEST(OptionalArrayRefTest, DanglingPointerFix) {
  // Ensure that the converting constructor of `OptionalArrayRef` does not
  // create a dangling pointer when given a single value or a one-element
  // braced list. ASSERT_EQ (unlike ASSERT_TRUE(a == b)) prints the actual
  // value on failure, which helps when garbage memory is read.
  ASSERT_EQ(get_first_element(300), 300);
  ASSERT_EQ(get_first_element({400}), 400);
}
105