1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker #include <functional>
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker using namespace torch::test;
10*da0073e9SAndroid Build Coastguard Worker
torch_warn_once_A()11*da0073e9SAndroid Build Coastguard Worker void torch_warn_once_A() {
12*da0073e9SAndroid Build Coastguard Worker TORCH_WARN_ONCE("warn once");
13*da0073e9SAndroid Build Coastguard Worker }
14*da0073e9SAndroid Build Coastguard Worker
torch_warn_once_B()15*da0073e9SAndroid Build Coastguard Worker void torch_warn_once_B() {
16*da0073e9SAndroid Build Coastguard Worker TORCH_WARN_ONCE("warn something else once");
17*da0073e9SAndroid Build Coastguard Worker }
18*da0073e9SAndroid Build Coastguard Worker
torch_warn()19*da0073e9SAndroid Build Coastguard Worker void torch_warn() {
20*da0073e9SAndroid Build Coastguard Worker TORCH_WARN("warn multiple times");
21*da0073e9SAndroid Build Coastguard Worker }
22*da0073e9SAndroid Build Coastguard Worker
// TORCH_WARN_ONCE call sites must surface their message a single time even
// when invoked repeatedly, whereas TORCH_WARN surfaces every invocation.
// Each check lives in its own scope with its own WarningCapture (presumably
// capturing only while the object is alive — see test/cpp/api/support.h).
TEST(UtilsTest, WarnOnce) {
  {
    WarningCapture captured;

    for (int i = 0; i < 2; ++i) {
      torch_warn_once_A();
    }
    for (int i = 0; i < 2; ++i) {
      torch_warn_once_B();
    }

    const auto log = captured.str();
    ASSERT_EQ(count_substr_occurrences(log, "warn once"), 1);
    ASSERT_EQ(count_substr_occurrences(log, "warn something else once"), 1);
  }
  {
    WarningCapture captured;

    for (int i = 0; i < 3; ++i) {
      torch_warn();
    }

    ASSERT_EQ(
        count_substr_occurrences(captured.str(), "warn multiple times"), 3);
  }
}
48*da0073e9SAndroid Build Coastguard Worker
// Under a NoGradGuard, grad mode must be disabled, and results computed from
// inputs that require grad must not themselves require grad. Consequently
// calling backward() on such a result fails, mirroring the Python API.
TEST(NoGradTest, SetsGradModeCorrectly) {
  torch::manual_seed(0);
  torch::NoGradGuard guard;
  // Check the guard's effect directly rather than only through the error
  // raised by backward() below.
  ASSERT_FALSE(at::GradMode::is_enabled());

  torch::nn::Linear model(5, 2);
  auto x = torch::randn({10, 5}, torch::requires_grad());
  auto y = model->forward(x);
  // Even though x requires grad, nothing computed under the guard should be
  // attached to an autograd graph.
  ASSERT_FALSE(y.requires_grad());
  torch::Tensor s = y.sum();
  ASSERT_FALSE(s.requires_grad());

  // Mimicking python API behavior:
  ASSERT_THROWS_WITH(
      s.backward(),
      "element 0 of tensors does not require grad and does not have a grad_fn");
}
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker struct AutogradTest : torch::test::SeedingFixture {
AutogradTestAutogradTest64*da0073e9SAndroid Build Coastguard Worker AutogradTest() {
65*da0073e9SAndroid Build Coastguard Worker x = torch::randn({3, 3}, torch::requires_grad());
66*da0073e9SAndroid Build Coastguard Worker y = torch::randn({3, 3});
67*da0073e9SAndroid Build Coastguard Worker z = x * y;
68*da0073e9SAndroid Build Coastguard Worker }
69*da0073e9SAndroid Build Coastguard Worker torch::Tensor x, y, z;
70*da0073e9SAndroid Build Coastguard Worker };
71*da0073e9SAndroid Build Coastguard Worker
// Backpropagating an all-ones upstream gradient through z = x * y must leave
// x.grad() equal to y.
TEST_F(AutogradTest, CanTakeDerivatives) {
  const auto upstream = torch::ones_like(z);
  z.backward(upstream);
  ASSERT_TRUE(x.grad().allclose(y));
}
76*da0073e9SAndroid Build Coastguard Worker
// Reducing z to a zero-dim tensor lets backward() be called with no explicit
// gradient; the result must still be dz/dx == y.
TEST_F(AutogradTest, CanTakeDerivativesOfZeroDimTensors) {
  auto total = z.sum();
  total.backward();
  ASSERT_TRUE(x.grad().allclose(y));
}
81*da0073e9SAndroid Build Coastguard Worker
// Seeding backward() on the zero-dim sum with 2 must scale the resulting
// gradient of x by the same factor, i.e. x.grad() == 2 * y.
TEST_F(AutogradTest, CanPassCustomGradientInputs) {
  auto total = z.sum();
  auto seed = torch::ones({}) * 2;
  total.backward(seed);
  ASSERT_TRUE(x.grad().allclose(y * 2));
}
86*da0073e9SAndroid Build Coastguard Worker
// Calls at::_test_ambiguous_defaults with each of the argument combinations
// below. There are no assertions: the test passes as long as every call
// compiles and returns without throwing.
TEST(UtilsTest, AmbiguousOperatorDefaults) {
  auto dummy = at::empty({}, at::kCPU);

  at::_test_ambiguous_defaults(dummy);
  at::_test_ambiguous_defaults(dummy, 1);
  at::_test_ambiguous_defaults(dummy, 1, 1);
  at::_test_ambiguous_defaults(dummy, 2, "2");
}
94*da0073e9SAndroid Build Coastguard Worker
get_first_element(c10::OptionalIntArrayRef arr)95*da0073e9SAndroid Build Coastguard Worker int64_t get_first_element(c10::OptionalIntArrayRef arr) {
96*da0073e9SAndroid Build Coastguard Worker return arr.value()[0];
97*da0073e9SAndroid Build Coastguard Worker }
98*da0073e9SAndroid Build Coastguard Worker
// Building an `OptionalArrayRef` from a single value, or from a braced
// single-element list, must not leave it pointing at a destroyed temporary.
// Reading the element back inside the callee must yield the original value.
TEST(OptionalArrayRefTest, DanglingPointerFix) {
  ASSERT_EQ(get_first_element(300), 300);
  ASSERT_EQ(get_first_element({400}), 400);
}
105