#include <gtest/gtest.h>

#include <torch/torch.h>

#include <test/cpp/api/support.h>

#include <functional>

using namespace torch::test;

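// Helpers that emit warnings through the TORCH_WARN_ONCE and TORCH_WARN
// macros. TORCH_WARN_ONCE only warns the first time a given call site is
// executed, while TORCH_WARN warns on every call.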
void torch_warn_once_A() {
  TORCH_WARN_ONCE("warn once");
}

void torch_warn_once_B() {
  TORCH_WARN_ONCE("warn something else once");
}

void torch_warn() {
  TORCH_WARN("warn multiple times");
}

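// Capture the emitted warnings and check that each TORCH_WARN_ONCE message
// appears exactly once, while the TORCH_WARN message appears once per call.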
TEST(UtilsTest, WarnOnce) {
  {
    WarningCapture warnings;

    torch_warn_once_A();
    torch_warn_once_A();
    torch_warn_once_B();
    torch_warn_once_B();

    ASSERT_EQ(count_substr_occurrences(warnings.str(), "warn once"), 1);
    ASSERT_EQ(
        count_substr_occurrences(warnings.str(), "warn something else once"),
        1);
  }
  {
    WarningCapture warnings;

    torch_warn();
    torch_warn();
    torch_warn();

    ASSERT_EQ(
        count_substr_occurrences(warnings.str(), "warn multiple times"), 3);
  }
}

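// torch::NoGradGuard disables gradient tracking for everything computed while
// it is in scope, so the output of the forward pass has no grad_fn and
// backward() is expected to throw.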
TEST(NoGradTest, SetsGradModeCorrectly) {
  torch::manual_seed(0);
  torch::NoGradGuard guard;
  torch::nn::Linear model(5, 2);
  auto x = torch::randn({10, 5}, torch::requires_grad());
  auto y = model->forward(x);
  torch::Tensor s = y.sum();

  // Mimicking the Python API's behavior: backward() on a tensor that does not
  // require grad should throw.
  ASSERT_THROWS_WITH(
      s.backward(),
      "element 0 of tensors does not require grad and does not have a grad_fn")
}

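// Fixture for the autograd tests below: z = x * y with x requiring grad, so
// the gradient of z (or z.sum()) with respect to x is y.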
struct AutogradTest : torch::test::SeedingFixture {
  AutogradTest() {
    x = torch::randn({3, 3}, torch::requires_grad());
    y = torch::randn({3, 3});
    z = x * y;
  }
  torch::Tensor x, y, z;
};

TEST_F(AutogradTest, CanTakeDerivatives) {
  z.backward(torch::ones_like(z));
  ASSERT_TRUE(x.grad().allclose(y));
}

TEST_F(AutogradTest, CanTakeDerivativesOfZeroDimTensors) {
  z.sum().backward();
  ASSERT_TRUE(x.grad().allclose(y));
}

TEST_F(AutogradTest, CanPassCustomGradientInputs) {
  z.sum().backward(torch::ones({}) * 2);
  ASSERT_TRUE(x.grad().allclose(y * 2));
}

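// at::_test_ambiguous_defaults is an operator that exists only for testing; it
// has two overloads whose trailing arguments are all defaulted. Each call
// below should compile and resolve to the intended overload without ambiguity.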
TEST(UtilsTest, AmbiguousOperatorDefaults) {
  auto tmp = at::empty({}, at::kCPU);
  at::_test_ambiguous_defaults(tmp);
  at::_test_ambiguous_defaults(tmp, 1);
  at::_test_ambiguous_defaults(tmp, 1, 1);
  at::_test_ambiguous_defaults(tmp, 2, "2");
}

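// Taking c10::OptionalIntArrayRef by value forces the implicit converting
// constructor to run at the call site, which is where the dangling-pointer
// issue checked below could occur.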
int64_t get_first_element(c10::OptionalIntArrayRef arr) {
  return arr.value()[0];
}

TEST(OptionalArrayRefTest, DanglingPointerFix) {
  // Ensure that the converting constructor of `OptionalArrayRef` does not
  // create a dangling pointer when given a single value.
  ASSERT_TRUE(get_first_element(300) == 300);
  ASSERT_TRUE(get_first_element({400}) == 400);
}