xref: /aosp_15_r20/external/pytorch/test/cpp/api/misc.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <torch/torch.h>
4 
5 #include <test/cpp/api/support.h>
6 
7 #include <functional>
8 
9 using namespace torch::test;
10 
torch_warn_once_A()11 void torch_warn_once_A() {
12   TORCH_WARN_ONCE("warn once");
13 }
14 
torch_warn_once_B()15 void torch_warn_once_B() {
16   TORCH_WARN_ONCE("warn something else once");
17 }
18 
torch_warn()19 void torch_warn() {
20   TORCH_WARN("warn multiple times");
21 }
22 
// Checks that each TORCH_WARN_ONCE call site is reported a single time even
// when invoked repeatedly, while a plain TORCH_WARN is reported on every call.
TEST(UtilsTest, WarnOnce) {
  {
    // Both once-only helpers are called twice (A, A, then B, B); the test
    // expects each message to show up exactly once in the captured output.
    WarningCapture captured;

    for (int call = 0; call < 2; ++call) {
      torch_warn_once_A();
    }
    for (int call = 0; call < 2; ++call) {
      torch_warn_once_B();
    }

    ASSERT_EQ(count_substr_occurrences(captured.str(), "warn once"), 1);
    ASSERT_EQ(
        count_substr_occurrences(captured.str(), "warn something else once"),
        1);
  }
  {
    // A fresh capture: the unguarded helper is expected to warn every time.
    WarningCapture captured;

    for (int call = 0; call < 3; ++call) {
      torch_warn();
    }

    ASSERT_EQ(
        count_substr_occurrences(captured.str(), "warn multiple times"), 3);
  }
}
48 
// With a NoGradGuard active, a module's output is not connected to the
// autograd graph, so backward() on it must fail with the same message the
// Python API produces.
TEST(NoGradTest, SetsGradModeCorrectly) {
  torch::manual_seed(0);
  torch::NoGradGuard no_grad;
  torch::nn::Linear linear(5, 2);
  auto input = torch::randn({10, 5}, torch::requires_grad());
  torch::Tensor loss = linear->forward(input).sum();

  // Mimicking python API behavior:
  ASSERT_THROWS_WITH(
      loss.backward(),
      "element 0 of tensors does not require grad and does not have a grad_fn")
}
62 
63 struct AutogradTest : torch::test::SeedingFixture {
AutogradTestAutogradTest64   AutogradTest() {
65     x = torch::randn({3, 3}, torch::requires_grad());
66     y = torch::randn({3, 3});
67     z = x * y;
68   }
69   torch::Tensor x, y, z;
70 };
71 
// Backpropagating an all-ones gradient through z = x * y should leave
// exactly y in x's gradient.
TEST_F(AutogradTest, CanTakeDerivatives) {
  const auto upstream = torch::ones_like(z);
  z.backward(upstream);

  const auto dx = x.grad();
  ASSERT_TRUE(dx.allclose(y));
}
76 
// Reducing z with sum() yields a zero-dim result, on which backward() is
// called with no explicit gradient; x's gradient should still equal y.
TEST_F(AutogradTest, CanTakeDerivativesOfZeroDimTensors) {
  auto total = z.sum();
  total.backward();

  ASSERT_TRUE(x.grad().allclose(y));
}
81 
// Seeding backward() with an explicit zero-dim gradient of 2 should scale
// x's gradient by the same factor, giving y * 2.
TEST_F(AutogradTest, CanPassCustomGradientInputs) {
  auto total = z.sum();
  auto seed = torch::ones({}) * 2;
  total.backward(seed);

  auto expected = y * 2;
  ASSERT_TRUE(x.grad().allclose(expected));
}
86 
// Calls at::_test_ambiguous_defaults with each argument arity/type mix.
// NOTE(review): appears to be mainly a compile-time check that every call
// form below resolves to a single overload; no results are inspected.
TEST(UtilsTest, AmbiguousOperatorDefaults) {
  const auto dummy = at::empty({}, at::kCPU);

  at::_test_ambiguous_defaults(dummy);
  at::_test_ambiguous_defaults(dummy, 1);
  at::_test_ambiguous_defaults(dummy, 1, 1);
  at::_test_ambiguous_defaults(dummy, 2, "2");
}
94 
get_first_element(c10::OptionalIntArrayRef arr)95 int64_t get_first_element(c10::OptionalIntArrayRef arr) {
96   return arr.value()[0];
97 }
98 
TEST(OptionalArrayRefTest, DanglingPointerFix) {
  // Ensure that the converting constructor of `OptionalArrayRef` does not
  // create a dangling pointer when given a single value.
  // ASSERT_EQ (rather than ASSERT_TRUE(a == b)) prints the value actually
  // read on failure, which is what reveals a dangling read.
  ASSERT_EQ(get_first_element(300), 300);
  ASSERT_EQ(get_first_element({400}), 400);
}
105