1*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/boxing/impl/test_helpers.h>
2*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/op_registration/op_registration.h>
5*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/FunctionsManual.h>
8*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/autograd/functions/basic_ops.h>
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker using namespace torch::autograd;
13*da0073e9SAndroid Build Coastguard Worker using namespace torch::test;
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker #define ASSERT_VARIABLE_EQ(a, b) ASSERT_TRUE(torch::allclose((a), (b)))
16*da0073e9SAndroid Build Coastguard Worker #define EXPECT_VARIABLE_EQ(a, b) EXPECT_TRUE(torch::allclose((a), (b)))
17*da0073e9SAndroid Build Coastguard Worker
graph_desc(std::shared_ptr<Node> node)18*da0073e9SAndroid Build Coastguard Worker std::string graph_desc(std::shared_ptr<Node> node) {
19*da0073e9SAndroid Build Coastguard Worker if (!node) {
20*da0073e9SAndroid Build Coastguard Worker return "None";
21*da0073e9SAndroid Build Coastguard Worker }
22*da0073e9SAndroid Build Coastguard Worker auto result = node->name() + "(";
23*da0073e9SAndroid Build Coastguard Worker auto next_edges = node->next_edges();
24*da0073e9SAndroid Build Coastguard Worker for (auto& edge : next_edges) {
25*da0073e9SAndroid Build Coastguard Worker result += graph_desc(edge.function);
26*da0073e9SAndroid Build Coastguard Worker }
27*da0073e9SAndroid Build Coastguard Worker return result + ")";
28*da0073e9SAndroid Build Coastguard Worker }
29*da0073e9SAndroid Build Coastguard Worker
// Small differentiable function shared by the tests below:
//   f(x, y) = x + 2 * y + x * y
// Its partial derivatives are df/dx = 1 + y and df/dy = 2 + x, which is what
// the expected-gradient assertions in this file check against.
Variable simple_fn(const Variable& x, const Variable& y) {
  auto doubled_y = 2 * y;
  auto product = x * y;
  return x + doubled_y + product;
}
33*da0073e9SAndroid Build Coastguard Worker
// Backward must complete when a void-returning hook is registered on a leaf
// whose gradient flows through UndefinedGrad. Per the test name, the hook is
// expected to cope with an undefined gradient tensor.
TEST(AutogradAPITests, RegisterHookVoidReturnAcceptsUndefinedTensor) {
  auto leaf = at::zeros({}, at::kCPU);
  leaf.requires_grad_();
  // No-op hook; a void return selects the non-replacing register_hook overload.
  leaf.register_hook([](at::TensorBase /*grad*/) {});
  auto outputs = torch::autograd::UndefinedGrad().apply({leaf});
  outputs[0].backward();
}
41*da0073e9SAndroid Build Coastguard Worker
// Same scenario as the void-return case, but the hook returns a tensor
// (here: the incoming gradient unchanged), selecting the tensor-returning
// register_hook overload.
TEST(AutogradAPITests, RegisterHookTensorReturnAcceptsUndefinedTensor) {
  auto leaf = at::zeros({}, at::kCPU);
  leaf.requires_grad_();
  leaf.register_hook([](at::Tensor grad) -> at::Tensor { return grad; });
  auto outputs = torch::autograd::UndefinedGrad().apply({leaf});
  outputs[0].backward();
}
49*da0073e9SAndroid Build Coastguard Worker
// torch::autograd::backward on a scalar (the sum of simple_fn), with an empty
// grad_tensors list, accumulates df/dx and df/dy into the leaves' .grad().
TEST(AutogradAPITests, BackwardSimpleTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  backward({simple_fn(x, y).sum()}, {});

  const auto ones = torch::ones({2, 2});
  ASSERT_VARIABLE_EQ(x.grad(), y + ones); // df/dx = 1 + y
  ASSERT_VARIABLE_EQ(y.grad(), x + ones * 2); // df/dy = 2 + x
}
59*da0073e9SAndroid Build Coastguard Worker
// Two backward passes through the same graph accumulate into .grad().
TEST(AutogradAPITests, BackwardTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);

  // First pass: retain_graph left unset, create_graph=true. The second pass
  // below reuses `res`'s graph, so that graph must still be alive afterwards.
  backward(
      {res},
      {torch::ones({2, 2})},
      /*retain_graph=*/{},
      /*create_graph=*/true);
  backward({res}, {torch::ones({2, 2})});

  // Each pass contributes df/dx = 1 + y and df/dy = 2 + x, hence the factor 2.
  auto expected_x_grad = 2 * (y + torch::ones({2, 2}));
  auto expected_y_grad = 2 * (x + torch::ones({2, 2}) * 2);
  ASSERT_VARIABLE_EQ(x.grad(), expected_x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), expected_y_grad);
}
71*da0073e9SAndroid Build Coastguard Worker
// Basic torch::autograd::grad: gradients are returned, one per requested
// input, in the order the inputs were given.
TEST(AutogradAPITests, GradSimpleTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto grads = grad({simple_fn(x, y)}, {x, y}, {torch::ones({2, 2})});

  auto expected_dx = y + torch::ones({2, 2});
  auto expected_dy = x + torch::ones({2, 2}) * 2;
  ASSERT_VARIABLE_EQ(grads[0], expected_dx);
  ASSERT_VARIABLE_EQ(grads[1], expected_dy);
}
82*da0073e9SAndroid Build Coastguard Worker
// Differentiates through gradients produced by a create_graph backward, and
// checks that grad() leaves the existing .grad() tensors untouched.
TEST(AutogradAPITests, GradTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  res.backward(
      torch::ones({2, 2}), /*retain_graph=*/false, /*create_graph=*/true);

  const Variable expected_x_grad = y + torch::ones({2, 2});
  const Variable expected_y_grad = x + torch::ones({2, 2}) * 2;
  ASSERT_VARIABLE_EQ(x.grad(), expected_x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), expected_y_grad);

  // 2 * (1 + y) + (2 + x) has derivative 1 w.r.t. x everywhere.
  Variable grad_sum = 2 * x.grad() + y.grad();
  auto hessian_vector = grad(
      {grad_sum},
      {x},
      {torch::ones({2, 2})},
      /*retain_graph=*/{},
      /*create_graph=*/true);
  ASSERT_VARIABLE_EQ(hessian_vector[0], torch::ones({2, 2}));

  // grad() must not have accumulated anything into .grad().
  ASSERT_VARIABLE_EQ(x.grad(), expected_x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), expected_y_grad);
}
101*da0073e9SAndroid Build Coastguard Worker
// Repeatedly takes grad() w.r.t. a (non-leaf after the first step) tensor and
// steps along it. No .grad() is populated along the way. A final backward
// through the unrolled chain reaches the original leaves.
TEST(AutogradAPITests, GradNonLeafTest) {
  Variable x_init = torch::randn({2, 2}, torch::requires_grad());
  Variable x = x_init;
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  Variable grad_output = torch::ones({2, 2});

  // Five ascent steps: x <- x + 0.05 * df/dx.
  for (int step = 0; step < 5; ++step) {
    auto step_grads = grad(
        {simple_fn(x, y)},
        {x},
        {grad_output},
        /*retain_graph=*/{},
        /*create_graph=*/true);

    ASSERT_VARIABLE_EQ(step_grads[0], y + torch::ones({2, 2}));
    ASSERT_FALSE(x.grad().defined());
    ASSERT_FALSE(y.grad().defined());
    x = x + 0.05 * step_grads[0];
  }

  const float val_init = simple_fn(x_init, y).sum().item().toFloat();
  const float val_final = simple_fn(x, y).sum().item().toFloat();
  ASSERT_GT(val_final, val_init);

  // Because every step was built with create_graph=true, backward from x
  // reaches both x_init and y.
  x.backward(grad_output, /*retain_graph=*/false, /*create_graph=*/true);
  ASSERT_TRUE(x_init.grad().defined());
  ASSERT_TRUE(y.grad().defined());
}
127*da0073e9SAndroid Build Coastguard Worker
// grad() with inputs that the output does not depend on.
TEST(AutogradAPITests, GradUnreachableTest) {
  Variable x = torch::ones({1}, torch::requires_grad());
  Variable y = torch::ones({1}, torch::requires_grad());

  // Use x and y in live graphs so they have grad accumulators allocated
  // before grad() runs (contrast with the fresh `z` case below). `w` is not
  // read again, but it keeps y's graph alive; do not remove it.
  Variable z = x * 2;
  Variable w = y * 2;

  // y is unreachable from x * 2. With allow_unused=true its gradient comes
  // back undefined. d(2x)/dx is 2, which equals x * 2 here since x is 1.
  auto grad_res = grad(
      {x * 2},
      {x, y},
      /*grad_outputs=*/{},
      /*retain_graph=*/{},
      /*create_graph=*/false,
      /*allow_unused=*/true);
  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());

  // This is slightly different than the case above, because z doesn't even
  // have a grad accumulator allocated.
  z = torch::ones({1}, torch::requires_grad());
  grad_res = grad(
      {x * 2},
      {x, z},
      /*grad_outputs=*/{},
      /*retain_graph=*/{},
      /*create_graph=*/false,
      /*allow_unused=*/true);
  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());

  // With allow_unused=false, an unreachable input must raise.
  ASSERT_THROWS_WITH(
      grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True");
}
151*da0073e9SAndroid Build Coastguard Worker
// Nodes on a path that cannot reach any requested input must not be executed
// when an input is unreachable. See #39784. MyFunction's backward records a
// test failure if it ever runs.
TEST(CustomAutogradTest, GradUnreachableDiscoveryTest) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* /*ctx*/, Variable var) {
      return var;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grad_output) {
      ADD_FAILURE() << "This node should not be executed!";
      return grad_output;
    }
  };

  // x2 depends only on x; y is a separate leaf, so grad(x2, y) has nothing
  // to compute.
  auto x = torch::randn(1, torch::requires_grad());
  auto offset = torch::randn(1);
  auto x2 = MyFunction::apply(x + offset);

  auto y = torch::randn(1, torch::requires_grad());
  auto grad_res = torch::autograd::grad(
      {x2},
      {y},
      /*grad_outputs=*/{},
      /*retain_graph=*/{},
      /*create_graph=*/false,
      /*allow_unused=*/true);
  ASSERT_FALSE(grad_res[0].defined());
}
176*da0073e9SAndroid Build Coastguard Worker
// grad() must reject an empty `inputs` list with an error.
TEST(AutogradAPITests, EmptyInput) {
  Variable x = torch::ones({1}, torch::requires_grad());
  ASSERT_THROWS_WITH(
      grad({x * 2}, /*inputs=*/{}, /*grad_outputs=*/{x}),
      "grad requires non-empty inputs.");
}
182*da0073e9SAndroid Build Coastguard Worker
// retain_grad() on a non-leaf makes its .grad() populated (and silences the
// non-leaf access warning). On a leaf it is a no-op.
TEST(AutogradAPITests, RetainGrad) {
  auto input = torch::rand({1, 3}, torch::requires_grad());
  auto h1 = input * 3;
  auto out = (h1 * h1).sum();

  {
    // Accessing .grad() of a non-leaf without retain_grad() warns.
    WarningCapture warnings;
    ASSERT_FALSE(h1.grad().defined());
    ASSERT_NE(warnings.str().find("is not a leaf"), std::string::npos);
  }
  // It should be possible to call retain_grad() multiple times.
  h1.retain_grad();
  h1.retain_grad();
  {
    // With retain_grad() set on the non-leaf, accessing .grad() is silent.
    WarningCapture warnings;
    ASSERT_FALSE(h1.grad().defined());
    ASSERT_EQ(warnings.str().find("is not a leaf"), std::string::npos);
  }

  // d(sum(h1 * h1))/dh1 = 2 * h1, accumulated across passes.
  out.backward({}, /*retain_graph=*/true);
  ASSERT_VARIABLE_EQ(h1 * 2, h1.grad());
  out.backward({}, /*retain_graph=*/true);
  ASSERT_VARIABLE_EQ(h1 * 4, h1.grad());

  {
    torch::NoGradGuard no_grad;
    input.grad().zero_();
  }
  // retain_grad() should be a no-op for leaves. One more pass gives
  // d(sum((3 * input)^2))/dinput = 18 * input.
  input.retain_grad();
  input.retain_grad();
  out.backward();
  ASSERT_VARIABLE_EQ(input * 18, input.grad());
}
223*da0073e9SAndroid Build Coastguard Worker
// With anomaly detection enabled, a failing backward should emit the forward
// traceback as a warning and then throw. The NaN check can be toggled by
// nested DetectAnomalyGuards.
TEST(AutogradAPITests, AnomalyMode) {
  torch::autograd::DetectAnomalyGuard detect_anomaly;
  {
    WarningCapture warnings;
    auto x = torch::tensor({5.0}, torch::requires_grad());
    auto y = x * x;
    auto z = y * y;
    // y feeds z's multiplication; modifying it in place should make
    // z.backward() fail.
    y += 1;
    ASSERT_THROWS_WITH(z.backward(), "inplace");
    ASSERT_TRUE(
        warnings.str().find("Traceback of forward") != std::string::npos);
  }

  // Double backward of x^1.5 at x = 0, with a zero grad_output, is expected
  // to produce NaN. When `should_throw` is true, the active guard is expected
  // to check for NaN, so grad() must throw and emit two traceback warnings.
  auto double_backward_produce_nan = [](bool should_throw) {
    auto x = torch::tensor({0.0}, torch::requires_grad());
    auto y = x.pow(1.5);
    auto gr = grad(
        {y},
        {x},
        /*grad_outputs=*/{},
        /*retain_graph=*/true,
        /*create_graph=*/true);
    if (should_throw) {
      WarningCapture warnings;
      ASSERT_THROWS_WITH(
          grad({gr[0]}, {x}, {torch::tensor({0.0})}), "returned nan");
      auto msgs = warnings.messages();
      // Unsigned literal: msgs.size() is a size_t.
      ASSERT_EQ(msgs.size(), 2u);
      ASSERT_TRUE(
          msgs[0].find("Traceback of forward call that caused the error") !=
          std::string::npos);
      ASSERT_TRUE(
          msgs[1].find(
              "Traceback of forward call that induced the previous calculation") !=
          std::string::npos);
    } else {
      grad({gr[0]}, {x}, {torch::tensor({0.0})});
    }
  };

  // The outer guard (default arguments) checks for NaN.
  double_backward_produce_nan(true);
  {
    // A nested guard overrides check_nan within its scope...
    torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/false);
    double_backward_produce_nan(false);
    {
      torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/true);
      double_backward_produce_nan(true);
    }
  }
  // ...and the enclosing setting is in effect again once it goes away.
  double_backward_produce_nan(true);
}
272*da0073e9SAndroid Build Coastguard Worker
// A custom Function may save its inputs for backward and return one of those
// inputs unchanged as its output. apply() must accept this.
TEST(CustomAutogradTest, CustomFunctionReturnInputAsIsAndSavesIt) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        Variable var2) {
      ctx->save_for_backward({var1, var2});
      // Return the input itself, unchanged.
      return var1;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  auto out = MyFunction::apply(x, y);
  // The output carries exactly the values of the returned input.
  ASSERT_VARIABLE_EQ(out, x);
}
294*da0073e9SAndroid Build Coastguard Worker
// Custom Function computing var1 + mul * var2 + var1 * var2 with a
// hand-written backward. The non-tensor `mul` argument is stashed in
// saved_data and receives an undefined gradient slot.
TEST(CustomAutogradTest, CustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(
        AutogradContext* ctx,
        Variable var1,
        int mul,
        Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      const int mul = ctx->saved_data["mul"].toInt();
      const auto saved = ctx->get_saved_variables();
      const auto& g = grad_output[0];
      // One gradient per forward input, in order: var1, mul, var2.
      return {g + g * saved[1], Variable(), g * mul + g * saved[0]};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  auto res = MyFunction::apply(x, 2, y);
  auto go = torch::ones({}, torch::requires_grad());
  res.sum().backward(go, /*retain_graph=*/false, /*create_graph=*/true);

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2);
}
331*da0073e9SAndroid Build Coastguard Worker
// Custom Function whose forward takes an at::TensorList and computes
// t0 + t1 + t0 * t1, with one gradient returned per list element.
TEST(CustomAutogradTest, CustomFunctionWithTensorList) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, at::TensorList tensors) {
      ctx->save_for_backward(
          torch::autograd::variable_list(tensors.begin(), tensors.end()));
      return tensors[0] + tensors[1] + tensors[0] * tensors[1];
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      const auto saved = ctx->get_saved_variables();
      const auto& g = grad_output[0];
      return {g + g * saved[1], g + g * saved[0]};
    }
  };

  at::Tensor x = torch::randn({5, 5}, torch::requires_grad());
  at::Tensor y = torch::randn({5, 5}, torch::requires_grad());
  // `tensors` is a non-owning view, so `variables` must outlive its use.
  torch::autograd::variable_list variables = {x, y};
  at::TensorList tensors = variables;
  auto res = MyFunction::apply(tensors);
  auto go = torch::ones({}, torch::requires_grad());
  res.sum().backward(go, /*retain_graph=*/false, /*create_graph=*/true);

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}));
}
367*da0073e9SAndroid Build Coastguard Worker
TEST(CustomAutogradTest,GraphTaskTrimEdges)368*da0073e9SAndroid Build Coastguard Worker TEST(CustomAutogradTest, GraphTaskTrimEdges) {
369*da0073e9SAndroid Build Coastguard Worker struct MyFunction : public Function<MyFunction> {
370*da0073e9SAndroid Build Coastguard Worker static Variable forward(
371*da0073e9SAndroid Build Coastguard Worker AutogradContext* ctx,
372*da0073e9SAndroid Build Coastguard Worker Variable var1,
373*da0073e9SAndroid Build Coastguard Worker Variable var2,
374*da0073e9SAndroid Build Coastguard Worker int mul,
375*da0073e9SAndroid Build Coastguard Worker bool needs_input1_grad,
376*da0073e9SAndroid Build Coastguard Worker bool needs_input2_grad) {
377*da0073e9SAndroid Build Coastguard Worker // setup the expected should and should not compute idx
378*da0073e9SAndroid Build Coastguard Worker ctx->saved_data["needs_input1_grad"] = needs_input1_grad;
379*da0073e9SAndroid Build Coastguard Worker ctx->saved_data["needs_input2_grad"] = needs_input2_grad;
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker ctx->saved_data["mul"] = mul;
382*da0073e9SAndroid Build Coastguard Worker ctx->save_for_backward({var1, var2});
383*da0073e9SAndroid Build Coastguard Worker return var1 + mul * var2 + var1 * var2;
384*da0073e9SAndroid Build Coastguard Worker }
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker static variable_list backward(
387*da0073e9SAndroid Build Coastguard Worker AutogradContext* ctx,
388*da0073e9SAndroid Build Coastguard Worker variable_list grad_output) {
389*da0073e9SAndroid Build Coastguard Worker // Test `needs_input_grad` method is working correctly.
390*da0073e9SAndroid Build Coastguard Worker // We have to test this within the backward function.
391*da0073e9SAndroid Build Coastguard Worker auto needs_input1_grad = ctx->saved_data["needs_input1_grad"].toBool();
392*da0073e9SAndroid Build Coastguard Worker auto needs_input2_grad = ctx->saved_data["needs_input2_grad"].toBool();
393*da0073e9SAndroid Build Coastguard Worker IndexRange var1_idx = {0, 1};
394*da0073e9SAndroid Build Coastguard Worker IndexRange var2_idx = {1, 2};
395*da0073e9SAndroid Build Coastguard Worker EXPECT_EQ(ctx->needs_input_grad(0), needs_input1_grad);
396*da0073e9SAndroid Build Coastguard Worker EXPECT_EQ(ctx->needs_input_grad(1), needs_input2_grad);
397*da0073e9SAndroid Build Coastguard Worker EXPECT_EQ(ctx->needs_input_grad({var1_idx}), needs_input1_grad);
398*da0073e9SAndroid Build Coastguard Worker EXPECT_EQ(ctx->needs_input_grad({var2_idx}), needs_input2_grad);
399*da0073e9SAndroid Build Coastguard Worker EXPECT_EQ(
400*da0073e9SAndroid Build Coastguard Worker ctx->needs_input_grad({var1_idx, var2_idx}),
401*da0073e9SAndroid Build Coastguard Worker needs_input1_grad || needs_input2_grad);
402*da0073e9SAndroid Build Coastguard Worker
403*da0073e9SAndroid Build Coastguard Worker // calculate gradients
404*da0073e9SAndroid Build Coastguard Worker int mul = ctx->saved_data["mul"].toInt();
405*da0073e9SAndroid Build Coastguard Worker auto saved = ctx->get_saved_variables();
406*da0073e9SAndroid Build Coastguard Worker auto var1 = saved[0];
407*da0073e9SAndroid Build Coastguard Worker auto var2 = saved[1];
408*da0073e9SAndroid Build Coastguard Worker
409*da0073e9SAndroid Build Coastguard Worker Variable grad_var1, grad_var2;
410*da0073e9SAndroid Build Coastguard Worker if (ctx->needs_input_grad(0)) {
411*da0073e9SAndroid Build Coastguard Worker grad_var1 = grad_output[0] + grad_output[0] * var2;
412*da0073e9SAndroid Build Coastguard Worker }
413*da0073e9SAndroid Build Coastguard Worker if (ctx->needs_input_grad(1)) {
414*da0073e9SAndroid Build Coastguard Worker grad_var2 = grad_output[0] * mul + grad_output[0] * var1;
415*da0073e9SAndroid Build Coastguard Worker }
416*da0073e9SAndroid Build Coastguard Worker variable_list output = {
417*da0073e9SAndroid Build Coastguard Worker grad_var1,
418*da0073e9SAndroid Build Coastguard Worker grad_var2,
419*da0073e9SAndroid Build Coastguard Worker Variable(),
420*da0073e9SAndroid Build Coastguard Worker Variable(),
421*da0073e9SAndroid Build Coastguard Worker Variable(),
422*da0073e9SAndroid Build Coastguard Worker };
423*da0073e9SAndroid Build Coastguard Worker return output;
424*da0073e9SAndroid Build Coastguard Worker }
425*da0073e9SAndroid Build Coastguard Worker };
426*da0073e9SAndroid Build Coastguard Worker
427*da0073e9SAndroid Build Coastguard Worker Variable x = torch::randn({5, 5}, torch::requires_grad());
428*da0073e9SAndroid Build Coastguard Worker Variable y = torch::randn({5, 5}, torch::requires_grad());
429*da0073e9SAndroid Build Coastguard Worker auto go = torch::ones_like(x);
430*da0073e9SAndroid Build Coastguard Worker Variable out;
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker // grad_x
433*da0073e9SAndroid Build Coastguard Worker out = MyFunction::apply(
434*da0073e9SAndroid Build Coastguard Worker x,
435*da0073e9SAndroid Build Coastguard Worker y,
436*da0073e9SAndroid Build Coastguard Worker 2,
437*da0073e9SAndroid Build Coastguard Worker /* needs_input1_grad= */ true,
438*da0073e9SAndroid Build Coastguard Worker /* needs_input2_grad= */ false);
439*da0073e9SAndroid Build Coastguard Worker auto grad_x = torch::autograd::grad({out}, {x}, {go})[0];
440*da0073e9SAndroid Build Coastguard Worker ASSERT_VARIABLE_EQ(grad_x, y + torch::ones({5, 5}));
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Worker // grad_y
443*da0073e9SAndroid Build Coastguard Worker out = MyFunction::apply(
444*da0073e9SAndroid Build Coastguard Worker x,
445*da0073e9SAndroid Build Coastguard Worker y,
446*da0073e9SAndroid Build Coastguard Worker 2,
447*da0073e9SAndroid Build Coastguard Worker /* needs_input1_grad= */ false,
448*da0073e9SAndroid Build Coastguard Worker /* needs_input2_grad= */ true);
449*da0073e9SAndroid Build Coastguard Worker auto grad_y = torch::autograd::grad({out}, {y}, {go})[0];
450*da0073e9SAndroid Build Coastguard Worker ASSERT_VARIABLE_EQ(grad_y, x + torch::ones({5, 5}) * 2);
451*da0073e9SAndroid Build Coastguard Worker
452*da0073e9SAndroid Build Coastguard Worker // grad_x and grad_y
453*da0073e9SAndroid Build Coastguard Worker out = MyFunction::apply(
454*da0073e9SAndroid Build Coastguard Worker x,
455*da0073e9SAndroid Build Coastguard Worker y,
456*da0073e9SAndroid Build Coastguard Worker 2,
457*da0073e9SAndroid Build Coastguard Worker /* needs_input1_grad= */ true,
458*da0073e9SAndroid Build Coastguard Worker /* needs_input2_grad= */ true);
459*da0073e9SAndroid Build Coastguard Worker auto grads = torch::autograd::grad({out}, {x, y}, {go});
460*da0073e9SAndroid Build Coastguard Worker ASSERT_VARIABLE_EQ(grads[0], y + torch::ones({5, 5}));
461*da0073e9SAndroid Build Coastguard Worker ASSERT_VARIABLE_EQ(grads[1], x + torch::ones({5, 5}) * 2);
462*da0073e9SAndroid Build Coastguard Worker }
463*da0073e9SAndroid Build Coastguard Worker
// A custom Function whose forward hands back its input tensor unchanged must
// still route gradients through its own backward, which doubles them.
TEST(CustomAutogradTest, FunctionReturnsInput) {
  struct ReturnInput : public Function<ReturnInput> {
    static Variable forward(AutogradContext* /*ctx*/, Variable input) {
      return input;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grads) {
      // Scale by 2 so the assertion can tell this backward actually ran.
      return {grads[0] * 2};
    }
  };

  Variable leaf(torch::ones(1, torch::requires_grad()));
  auto result = ReturnInput::apply(leaf);
  result.backward(
      torch::ones(1), /*retain_graph=*/true, /*create_graph=*/true);
  ASSERT_VARIABLE_EQ(leaf.grad(), torch::full(1, 2.));
}
481*da0073e9SAndroid Build Coastguard Worker
// When backward yields an undefined tensor, the input's .grad() must stay
// undefined in three cases: applied directly, applied after another op, and
// reduced with sum() first. torch::autograd::grad (with allow_unused=true)
// must likewise report an undefined gradient.
TEST(CustomAutogradTest, FunctionReturnsUndefined) {
  struct DropGrad : public Function<DropGrad> {
    static Variable forward(AutogradContext* /*ctx*/, Variable input) {
      return input * 2;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list /*grads*/) {
      return {at::Tensor()};
    }
  };

  auto x = torch::ones(1, torch::requires_grad());

  DropGrad::apply(x).backward();
  ASSERT_FALSE(x.grad().defined());

  DropGrad::apply(x.pow(2)).backward();
  ASSERT_FALSE(x.grad().defined());

  DropGrad::apply(x).sum().backward();
  ASSERT_FALSE(x.grad().defined());

  auto grads = torch::autograd::grad(
      {DropGrad::apply(x)},
      {x},
      /*grad_outputs=*/{},
      /*retain_graph=*/false,
      /*create_graph=*/false,
      /*allow_unused=*/true);
  ASSERT_FALSE(grads[0].defined());
}
511*da0073e9SAndroid Build Coastguard Worker
// With the default materialize-grads setting, an undefined incoming gradient
// (UndefinedGrad sits downstream of the Function) is handed to backward as a
// zero-filled tensor.
TEST(CustomAutogradTest, MaterializeGrads) {
  struct PassThrough : public Function<PassThrough> {
    static Variable forward(AutogradContext* /*ctx*/, Variable input) {
      return input;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grads) {
      EXPECT_VARIABLE_EQ(grads[0], torch::zeros(1));
      return grads;
    }
  };

  auto x = torch::ones(1, torch::requires_grad());
  auto wrapped = UndefinedGrad().apply({PassThrough::apply(x)});
  wrapped[0].backward();
}
529*da0073e9SAndroid Build Coastguard Worker
// After ctx->set_materialize_grads(false), an undefined incoming gradient
// (UndefinedGrad sits downstream of the Function) reaches backward still
// undefined.
TEST(CustomAutogradTest, DontMaterializeGrads) {
  struct PassThroughRaw : public Function<PassThroughRaw> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      ctx->set_materialize_grads(false);
      return input;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grads) {
      EXPECT_FALSE(grads[0].defined());
      return grads;
    }
  };

  auto x = torch::ones(1, torch::requires_grad());
  auto wrapped = UndefinedGrad().apply({PassThroughRaw::apply(x)});
  wrapped[0].backward();
}
548*da0073e9SAndroid Build Coastguard Worker
// A custom Function must respect grad mode: under NoGradGuard its output does
// not require grad, even though the input does.
TEST(CustomAutogradTest, NoGradCustomFunction) {
  struct AddOne : public Function<AddOne> {
    static Variable forward(AutogradContext* /*ctx*/, Variable input) {
      return input + 1;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grads) {
      return grads;
    }
  };

  auto x = torch::ones({5, 5}, torch::requires_grad());
  {
    at::NoGradGuard guard;
    auto result = AddOne::apply(x);
    ASSERT_FALSE(result.requires_grad());
  }
}
568*da0073e9SAndroid Build Coastguard Worker
// A custom Function that writes into its input in place and reports it via
// ctx->mark_dirty() must bump the input's version counter. Gradients must
// still flow through the Function's backward back to the original leaf.
TEST(CustomAutogradTest, MarkDirty) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable v) {
      // Change the value in place through the raw data pointer, then tell
      // autograd that v was modified.
      auto v_data = v.data_ptr<float>();
      v_data[0] = 2;
      ctx->mark_dirty({v});
      return v;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      return {(grad_output[0] * 2.0)};
    }
  };

  // Clone here because modifying leaves in place is not allowed. Keep the
  // leaf so its gradient can be checked afterwards.
  auto leaf = torch::randn({5, 5}, torch::requires_grad());
  auto x = leaf.clone();
  auto version_before = x._version();
  auto out = MyFunction::apply(x);
  auto version_after = x._version();
  // ASSERT_GE prints both operands on failure, unlike ASSERT_TRUE(a >= b).
  ASSERT_GE(version_after, version_before + 1);
  out.sum().backward();
  // backward doubles the incoming ones; clone passes them straight through.
  ASSERT_VARIABLE_EQ(leaf.grad(), torch::full({5, 5}, 2.));
}
594*da0073e9SAndroid Build Coastguard Worker
// An output marked non-differentiable does not require grad. It can still be
// used as a mask in a differentiable op whose result is backpropagated.
TEST(CustomAutogradTest, MarkNonDifferentiable) {
  struct PositiveMask : public Function<PositiveMask> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      Variable is_positive = input > 0;
      ctx->mark_non_differentiable({is_positive});
      return is_positive;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grads) {
      return {(grads[0] * 0.0)};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto mask = PositiveMask::apply(x);
  ASSERT_FALSE(mask.requires_grad());
  x.masked_fill(mask, 0).sum().backward();
}
616*da0073e9SAndroid Build Coastguard Worker
// When only one of two outputs is marked non-differentiable, only the other
// output requires grad. backward then sees a zero gradient for the frozen
// output and ones for the live output.
TEST(CustomAutogradTest, MarkNonDifferentiableMixed) {
  struct TwoOutputs : public Function<TwoOutputs> {
    static variable_list forward(AutogradContext* ctx, Variable input) {
      Variable frozen = input + 1;
      Variable live = input + 2;
      ctx->mark_non_differentiable({frozen});
      return {frozen, live};
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grads) {
      const Variable& grad_frozen = grads[0];
      const Variable& grad_live = grads[1];
      EXPECT_VARIABLE_EQ(grad_frozen, torch::zeros({5, 5}));
      EXPECT_VARIABLE_EQ(grad_live, torch::ones({5, 5}));
      return {grad_live};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto outs = TwoOutputs::apply(x);

  ASSERT_FALSE(outs[0].requires_grad());
  ASSERT_TRUE(outs[1].requires_grad());
  outs[1].sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5}));
}
644*da0073e9SAndroid Build Coastguard Worker
// A Function whose only output is marked non-differentiable: the test checks
// that backpropagating through an expression mixing that output with x
// completes without error.
TEST(CustomAutogradTest, MarkNonDifferentiableNone) {
  struct CloneNoGrad : public Function<CloneNoGrad> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      auto copy = input.clone();
      ctx->mark_non_differentiable({copy});
      return copy;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list /*grads*/) {
      return {};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto r = CloneNoGrad::apply(x * x);
  auto loss = (r * x).sum();
  loss.backward();
}
664*da0073e9SAndroid Build Coastguard Worker
// A Function modifies a leaf that does not require grad in place (marked
// dirty) and returns it. The returned tensor equals x and now requires grad.
// Backpropagating through it delivers ones to y.
TEST(CustomAutogradTest, ReturnLeafInplace) {
  struct Inplace : public Function<Inplace> {
    static variable_list forward(AutogradContext* ctx, Variable a, Variable b) {
      ctx->mark_dirty({a});
      return {a.add_(b), b + 2};
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grads) {
      return {grads[0], grads[0] + grads[1]};
    }
  };

  Variable x = torch::randn({5, 5});
  Variable y = torch::randn({5, 5}, torch::requires_grad());

  auto results = Inplace::apply(x, y);
  auto& modified = results[0];
  ASSERT_TRUE(torch::equal(modified, x));
  ASSERT_TRUE(modified.requires_grad());
  modified.sum().backward();
  ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5, 5}));
}
689*da0073e9SAndroid Build Coastguard Worker
// A Function modifies its input in place, marks it dirty, and returns it
// twice. Applying it to a leaf that requires grad must throw. Applying it to
// a non-leaf clone yields two equal outputs.
TEST(CustomAutogradTest, DISABLED_PLACEHOLDER_NEVER_USED_ReturnDuplicateInplace_Alias);
715*da0073e9SAndroid Build Coastguard Worker
// A Function may return the same freshly computed tensor twice. Both outputs
// must compare equal.
TEST(CustomAutogradTest, ReturnDuplicate) {
  struct DoubleDuplicate : public Function<DoubleDuplicate> {
    static variable_list forward(AutogradContext* /*ctx*/, Variable x) {
      auto doubled = x * 2;
      return {doubled, doubled};
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grads) {
      return {grads[0] * 2 + grads[1] * 2};
    }
  };

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto outs = DoubleDuplicate::apply(x);
  ASSERT_TRUE(torch::equal(outs[0], outs[1]));
}
734*da0073e9SAndroid Build Coastguard Worker
// Undefined tensors passed to save_for_backward come back undefined from
// get_saved_variables(). The defined one is returned at its original index.
TEST(CustomAutogradTest, SaveEmptyForBackward) {
  struct Square : public Function<Square> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      ctx->save_for_backward({Variable(), input, Variable()});
      return input * input;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grads) {
      auto saved = ctx->get_saved_variables();
      EXPECT_FALSE(saved[0].defined());
      EXPECT_FALSE(saved[2].defined());
      const auto& input = saved[1];
      // d(x * x)/dx = 2x
      return {input * 2 * grads[0]};
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  auto squared = Square::apply(x);
  squared.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), 2 * x);
}
757*da0073e9SAndroid Build Coastguard Worker
// A backward that returns a gradient whose shape (10) does not match the
// input's shape (5x5) must make backward() throw with "expected shape".
TEST(CustomAutogradTest, InvalidGradients) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      return x * 2;
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_outputs) {
      // Deliberately wrong: shape {10} instead of the input's {5, 5}.
      return {
          torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))};
    }
  };

  auto input1 =
      torch::randn({5, 5}, torch::dtype(torch::kFloat).requires_grad(true));
  ASSERT_THROWS_WITH(
      MyFunction::apply(input1).sum().backward(), "expected shape");
  // (Removed an unused `input2` double tensor that no assertion referenced.)
}
779*da0073e9SAndroid Build Coastguard Worker
// Applying a Function under NoGradGuard leaves the input's requires_grad
// untouched but gives the output no grad_fn.
TEST(CustomAutogradTest, NoGradInput) {
  struct Passthrough : public Function<Passthrough> {
    static Variable forward(AutogradContext* /*ctx*/, Variable input) {
      return input;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grads) {
      return grads;
    }
  };

  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable result;
  {
    at::NoGradGuard guard;
    result = Passthrough::apply(x);
  }

  ASSERT_TRUE(x.requires_grad());
  ASSERT_FALSE(result.grad_fn());
}
803*da0073e9SAndroid Build Coastguard Worker
// A backward that returns extra, undefined gradients beyond the number of
// forward inputs must still deliver its defined gradient to x.
TEST(CustomAutogradTest, TooManyGrads) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable input) {
      return input;
    }

    static variable_list backward(AutogradContext*, variable_list grad_output) {
      // Pad with two undefined gradients although forward had one input.
      grad_output.insert(grad_output.end(), {Variable(), Variable()});
      return grad_output;
    }
  };

  // Actually apply the function and run backward. Without this, the struct
  // above is never exercised and the test asserts nothing.
  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto y = MyFunction::apply(x);
  y.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones_like(x));
}
816*da0073e9SAndroid Build Coastguard Worker
// F1 returns its input plus a fresh random tensor marked non-differentiable.
// F2 consumes a differentiable input together with a value derived from that
// non-differentiable output. Backward through F2 must still deliver ones to x
// via F1.
TEST(CustomAutogradTest, DepNoGrad) {
  struct F1 : public Function<F1> {
    static variable_list forward(AutogradContext* ctx, Variable input) {
      auto noise = torch::randn(input.sizes());
      ctx->mark_non_differentiable({noise});
      return {input, noise};
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grads) {
      return {grads[0]};
    }
  };

  struct F2 : public Function<F2> {
    static Variable forward(
        AutogradContext* /*ctx*/,
        Variable input,
        Variable /*ignore*/) {
      return input;
    }

    static variable_list backward(
        AutogradContext* /*ctx*/,
        variable_list grads) {
      return {grads[0], Variable()};
    }
  };

  auto x = torch::randn(5, torch::requires_grad());
  auto f1_outputs = F1::apply(x);
  Variable& differentiable = f1_outputs[0];
  Variable& detached = f1_outputs[1];
  detached = detached + 1; // Separate F1 and F2 by another operation
  ASSERT_TRUE(differentiable.requires_grad());
  ASSERT_FALSE(detached.requires_grad());

  auto c = F2::apply(differentiable, detached);
  c.backward(
      torch::ones(c.sizes()), /*retain_graph=*/false, /*create_graph=*/false);
  ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes()));
}
853*da0073e9SAndroid Build Coastguard Worker
// A custom Function's backward may itself call backward() (reentrant
// autograd). forward records an inner graph x * y with grad mode re-enabled.
// backward runs that inner graph and returns the gradient it left on the inner
// x, so the outer gradient for x must equal y_data.
TEST(CustomAutogradTest, Reentrant) {
  static Variable y_data = torch::randn({2, 2});
  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext* ctx, Variable input) {
      Variable product;
      {
        // Re-enable grad so the inner product records its own graph.
        at::AutoGradMode enable_grad(true);
        auto inner_x = make_variable(input.tensor_data(), true);
        auto inner_y = make_variable(y_data.tensor_data(), true);
        product = inner_x * inner_y;

        ctx->saved_data["x"] = inner_x;
        ctx->saved_data["y"] = inner_y;
        ctx->saved_data["output_var"] = product;
      }
      // Only the detached value leaves forward; the inner graph stays in ctx.
      return product.detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      {
        at::AutoGradMode enable_grad(true);
        // Reentrant call: backprop through the inner graph saved in forward.
        auto inner_out = ctx->saved_data["output_var"].toTensor();
        inner_out.sum().backward();
      }
      auto inner_x_grad = ctx->saved_data["x"].toTensor().grad();
      return {inner_x_grad * grad_output[0]};
    }
  };

  auto x = torch::randn({2, 2}, torch::requires_grad());
  auto result = Reenter::apply(x);
  result.sum().backward();
  ASSERT_VARIABLE_EQ(x.grad(), y_data);
}
889*da0073e9SAndroid Build Coastguard Worker
890*da0073e9SAndroid Build Coastguard Worker // NOTE: If this fails for apparently unrelated reasons in TSAN be aware of
891*da0073e9SAndroid Build Coastguard Worker // the TSAN limit on mutex: https://github.com/google/sanitizers/issues/950
TEST(CustomAutogradTest, DeepReentrant) {
  // Function that saves (x - 1). Its backward re-applies itself to that saved
  // value and runs a nested backward, until the saved value reaches zero.
  // The reentrant nesting depth therefore grows with the starting value.
  struct DeepReenter : public Function<DeepReenter> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x"].toTensor().detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      auto saved = ctx->saved_data["x"].toTensor();
      if (at::native::is_nonzero(saved)) {
        // Recurse one level deeper via a reentrant backward call.
        at::AutoGradMode enable_grad(true);
        apply(saved)[0].sum().backward();
      }
      return grad_output;
    }
  };

  // This should not stack overflow
  auto v =
      torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true));
  DeepReenter::apply(v).sum().backward();
}
921*da0073e9SAndroid Build Coastguard Worker
TEST(CustomAutogradTest, ReentrantPriority) {
  // Records which backward ran, in execution order: 0 = MyFunction,
  // 1 = Reenter. Because it is a function-local static, the static member
  // functions of the local structs below can reach it without captures.
  static std::vector<int> order;
  // Reset up front as well. If an ASSERT_* below fails, the test returns
  // early and the trailing clear() never runs, which would leave stale
  // entries behind for a repeated run (e.g. --gtest_repeat).
  order.clear();

  // Identity Function whose backward only records that it ran.
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext*, Variable x) {
      return x;
    }

    static variable_list backward(AutogradContext*, variable_list grad) {
      order.push_back(0);
      return grad;
    }
  };

  // Function that saves (x - 1). While the saved value is nonzero, its
  // backward re-applies the Function to that value and runs a nested
  // (reentrant) backward.
  struct Reenter : public Function<Reenter> {
    static Variable forward(AutogradContext* ctx, Variable x) {
      {
        at::AutoGradMode enable_grad(true);
        ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
      }
      return ctx->saved_data["x"].toTensor().detach();
    }

    static variable_list backward(
        AutogradContext* ctx,
        variable_list grad_output) {
      order.push_back(1);
      if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
        return grad_output;
      }
      {
        at::AutoGradMode enable_grad(true);
        apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
        return grad_output;
      }
    }
  };

  auto a = MyFunction::apply(
      torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto b = Reenter::apply(
      torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true)));
  auto v = a * b;
  v.backward();

  // Starting from 9, Reenter's backward runs 9 times (saved values 8..0),
  // plus one MyFunction backward. All the reentrant tasks should be
  // prioritized over the MyFunction backward task, so it must come last.
  ASSERT_EQ(order.size(), 10u);
  ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9);
  ASSERT_EQ(order.back(), 0);
  // Clear static variable in case the test gets executed in a loop.
  order.clear();
}
975*da0073e9SAndroid Build Coastguard Worker
TEST(CustomAutogradTest, Hooks) {
  Variable x = torch::ones({5, 5}, torch::requires_grad());
  Variable y = torch::ones({5, 5}) * 4;
  y.set_requires_grad(true);

  // Running sum of the increments passed by every counting hook that fired.
  int fired_total = 0;
  std::function<void(int, Variable)> bw_hook(
      [&fired_total](int inc, Variable grad) { fired_total += inc; });
  // Hook that doubles the gradient flowing through it.
  std::function<Variable(Variable)> bw_hook_modify(
      [](Variable grad) { return grad.mul(2); });

  Variable z = x * x + x * 2 + x * y + y;

  // Runs z.backward() with a freshly allocated all-ones gradient every time.
  auto run_backward = [&z](bool retain_graph, bool create_graph) {
    z.backward(torch::ones({5, 5}), retain_graph, create_graph);
  };

  // x's hook contributes 0 and z's first hook contributes 1.
  x.register_hook([&bw_hook](Variable grad) { bw_hook(0, grad); });
  auto z_hook_one =
      z.register_hook([&bw_hook](Variable grad) { bw_hook(1, grad); });
  run_backward(true, true);
  ASSERT_EQ(fired_total, 1);

  // A second hook on z adds 2 on top of the first hook's 1.
  auto z_hook_two =
      z.register_hook([&bw_hook](Variable grad) { bw_hook(2, grad); });
  run_backward(true, true);
  ASSERT_EQ(fired_total, 4);

  // Once the second hook is removed, only the first one (+1) fires.
  z.remove_hook(z_hook_two);
  run_backward(true, true);
  ASSERT_EQ(fired_total, 5);

  // A gradient-modifying hook on z doubles dz/dy = x + 1.
  z.remove_hook(z_hook_one);
  z.register_hook(bw_hook_modify);
  y.grad().zero_();
  run_backward(true, false);
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2);

  // Another doubling hook directly on y multiplies the result again.
  y.register_hook(bw_hook_modify);
  y.grad().zero_();
  run_backward(false, false);
  ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4);

  ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
}
1018*da0073e9SAndroid Build Coastguard Worker
TEST(CustomAutogradTest, HooksInplace) {
  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();

  // Builds a hook that increments `calls` and checks that the incoming
  // gradient equals ones({5, 5}) * scale.
  auto make_hook = [](int& calls, int scale) {
    return [&counter = calls, scale](Variable grad) {
      ++counter;
      ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * scale);
    };
  };

  int hook1_count = 0;
  int hook2_count = 0;
  // hook1 is attached before the in-place mul_(2), so it sees a gradient of
  // 2; hook2 is attached after it and sees a gradient of 1.
  auto hook1 = make_hook(hook1_count, 2);
  auto hook2 = make_hook(hook2_count, 1);

  a.register_hook(hook1);
  a.mul_(2);
  a.register_hook(hook2);

  auto out = (a + 1).sum();
  out.backward();

  ASSERT_EQ(hook1_count, 1);
  ASSERT_EQ(hook2_count, 1);
}
1044*da0073e9SAndroid Build Coastguard Worker
TEST(CustomAutogradTest, HooksInplaceWithRetainsGrad) {
  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();

  // Builds a hook that increments `calls` and checks that the incoming
  // gradient equals ones({5, 5}) * scale.
  auto make_hook = [](int& calls, int scale) {
    return [&counter = calls, scale](Variable grad) {
      ++counter;
      ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * scale);
    };
  };

  int hook1_count = 0;
  int hook2_count = 0;
  int hook3_count = 0;
  // hook1 and hook2 are attached (around retain_grad) before mul_(2) and see
  // a gradient of 2; hook3 is attached after it and sees a gradient of 1.
  auto hook1 = make_hook(hook1_count, 2);
  auto hook2 = make_hook(hook2_count, 2);
  auto hook3 = make_hook(hook3_count, 1);

  a.register_hook(hook1);
  a.retain_grad();
  a.register_hook(hook2);

  a.mul_(2);
  a.register_hook(hook3);

  auto out = (a + 1).sum();
  out.backward();

  ASSERT_EQ(hook1_count, 1);
  ASSERT_EQ(hook2_count, 1);
  ASSERT_EQ(hook3_count, 1);

  // The retained grad is the one for the post-mul_ value of `a`.
  ASSERT_TRUE(a.retains_grad());
  ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
}
1083*da0073e9SAndroid Build Coastguard Worker
TEST(CustomAutogradTest, HooksInplaceTwiceWithRetainsGrad) {
  auto a = torch::ones({5, 5}, torch::requires_grad()).clone();

  // Builds a hook that increments `calls` and checks that the incoming
  // gradient equals ones({5, 5}) * scale.
  auto make_hook = [](int& calls, int scale) {
    return [&counter = calls, scale](Variable grad) {
      ++counter;
      ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * scale);
    };
  };

  int hook1_count = 0;
  int hook2_count = 0;
  int hook3_count = 0;
  // hook1 and hook2 precede two in-place mul_(2) calls and see a gradient of
  // 4; hook3 follows them and sees a gradient of 1.
  auto hook1 = make_hook(hook1_count, 4);
  auto hook2 = make_hook(hook2_count, 4);
  auto hook3 = make_hook(hook3_count, 1);

  a.register_hook(hook1);
  a.retain_grad();
  a.register_hook(hook2);

  a.mul_(2);
  a.mul_(2);
  a.register_hook(hook3);

  auto out = (a + 1).sum();
  out.backward();

  ASSERT_EQ(hook1_count, 1);
  ASSERT_EQ(hook2_count, 1);
  ASSERT_EQ(hook3_count, 1);

  // The retained grad is the one for the final (post-mul_) value of `a`.
  ASSERT_TRUE(a.retains_grad());
  ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
}
1123*da0073e9SAndroid Build Coastguard Worker
TEST(CustomAutogradTest, HookNone) {
  // Passes both inputs through unchanged; its backward returns an undefined
  // gradient for the second input.
  struct NoneGradientFunction : public Function<NoneGradientFunction> {
    static variable_list forward(AutogradContext* ctx, Variable x, Variable y) {
      return {x, y};
    }

    static variable_list backward(AutogradContext* ctx, variable_list grad) {
      return {grad[0], Variable()};
    }
  };

  bool was_called = false;

  // Every gradient handed to a hook on the Function's outputs must be defined.
  auto hook = ([&was_called](Variable grad) {
    ASSERT_TRUE(grad.defined());
    was_called = true;
  });

  auto x = torch::randn({5, 5}, torch::requires_grad());
  auto y = torch::randn({5, 5});

  auto out = NoneGradientFunction::apply(x, y);
  // Hook the Function's outputs. Indexing `x` here (x[0], x[1]) would select
  // two rows of the input and never route gradients through
  // NoneGradientFunction, making the test vacuous.
  Variable rx = out[0], ry = out[1];

  rx.register_hook(hook);
  ry.register_hook(hook);
  (rx + ry).sum().backward();
  ASSERT_TRUE(was_called);
}
1153*da0073e9SAndroid Build Coastguard Worker
TEST(CustomAutogradTest, BackwardWithInputs) {
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable z = x * x + x * y + y * y;
  // dz/dx for z = x^2 + x*y + y^2. No expectation is needed for y: it is not
  // in `inputs`, so its grad must stay undefined (checked below).
  Variable x_grad_expected = 2 * x + y;

  // Accumulate gradients into x only.
  z.backward(torch::ones({5, 5}), false, false, {x});

  ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
  ASSERT_FALSE(y.grad().defined());
}
1166*da0073e9SAndroid Build Coastguard Worker
TEST(CustomAutogradTest, BackwardWithEmptyInputs) {
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  Variable z = x * x + x * y + y * y;
  // Passing an explicitly empty `inputs` list must make backward() throw
  // with a "cannot be empty" message.
  ASSERT_THROWS_WITH(
      z.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{}),
      "cannot be empty");
}
1177*da0073e9SAndroid Build Coastguard Worker
TEST(CustomAutogradTest, BackwardWithNonLeafInputs) {
  Variable x = torch::randn({5, 5}, torch::requires_grad());
  Variable y = torch::randn({5, 5}, torch::requires_grad());
  // z is an intermediate (non-leaf) result; w depends on x both directly and
  // through z.
  Variable z = x * x;
  Variable w = y * z + x * y + y * y;

  // Expected gradients:
  //   dw/dx = y * 2x + y   (includes the path through z)
  //   dw/dz = y
  Variable expected_dx = 2 * x * y + y;
  Variable expected_dz = y;

  // Only x and z are listed as inputs, so y must not receive a grad.
  w.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{x, z});

  ASSERT_VARIABLE_EQ(x.grad(), expected_dx);
  ASSERT_VARIABLE_EQ(z.grad(), expected_dz);
  ASSERT_FALSE(y.grad().defined());
}
1193*da0073e9SAndroid Build Coastguard Worker
TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) {
  c10::WarningUtils::WarnAlways guard(true);

  // True when the captured warning text contains the create_graph warning.
  auto mentions_create_graph = [](const std::string& text) {
    return text.find("Using backward() with create_graph=True") !=
        std::string::npos;
  };

  torch::Tensor x = torch::randn({5, 5}).set_requires_grad(true);
  auto z = x * x;

  // Entry point 1: Tensor::backward with create_graph = true.
  {
    WarningCapture warnings;
    z.backward(torch::ones({5, 5}), std::nullopt, true);
    ASSERT_TRUE(mentions_create_graph(warnings.str()));
  }

  // Entry point 2: torch::autograd::backward with create_graph = true.
  {
    WarningCapture warnings;
    torch::autograd::backward({z}, {torch::ones({5, 5})}, std::nullopt, true);
    ASSERT_TRUE(mentions_create_graph(warnings.str()));
  }
}
1215*da0073e9SAndroid Build Coastguard Worker
/**
 * Tests for AutogradNotImplementedFallback
 * - Check that the NotImplemented node is created when some input requires
 *   grad, and that no such node is created when no input requires grad
 * - check_inplace logic
 * - view ops
 * - TODO: Tests for debug-only checks? Not needed for now because CI doesn't
 *   test non-NDEBUG builds.
 * - tensorlist input and output
 * - multiple outputs / non-tensor output
 * - rebase_history vs set_history
 */
1228*da0073e9SAndroid Build Coastguard Worker namespace {
1229*da0073e9SAndroid Build Coastguard Worker
// Adds `other` into `self` in place and returns `self` (the result of add_).
torch::Tensor inplace_op(const torch::Tensor& self, const torch::Tensor& other) {
  const auto& updated = self.add_(other);
  return updated;
}
1235*da0073e9SAndroid Build Coastguard Worker
two_arg_inplace_op(const torch::Tensor & self,const torch::Tensor & other)1236*da0073e9SAndroid Build Coastguard Worker std::tuple<torch::Tensor, torch::Tensor> two_arg_inplace_op(
1237*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& self,
1238*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& other) {
1239*da0073e9SAndroid Build Coastguard Worker other.add_(self);
1240*da0073e9SAndroid Build Coastguard Worker self.add_(other);
1241*da0073e9SAndroid Build Coastguard Worker return std::tuple<torch::Tensor, torch::Tensor>(self, other);
1242*da0073e9SAndroid Build Coastguard Worker }
1243*da0073e9SAndroid Build Coastguard Worker
two_pairs_of_view_op(const torch::Tensor & self,const torch::Tensor & other)1244*da0073e9SAndroid Build Coastguard Worker std::tuple<torch::Tensor, torch::Tensor> two_pairs_of_view_op(
1245*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& self,
1246*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& other) {
1247*da0073e9SAndroid Build Coastguard Worker // This is not allowed. We test below that this calling into the boxed kernel
1248*da0073e9SAndroid Build Coastguard Worker // will raise an error
1249*da0073e9SAndroid Build Coastguard Worker return std::tuple<torch::Tensor, torch::Tensor>(self, other);
1250*da0073e9SAndroid Build Coastguard Worker }
1251*da0073e9SAndroid Build Coastguard Worker
non_first_view_op(const torch::Tensor & self,const torch::Tensor & other)1252*da0073e9SAndroid Build Coastguard Worker std::tuple<torch::Tensor, torch::Tensor> non_first_view_op(
1253*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& self,
1254*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& other) {
1255*da0073e9SAndroid Build Coastguard Worker // This is not allowed. We test below that this calling into the boxed kernel
1256*da0073e9SAndroid Build Coastguard Worker // will raise an error
1257*da0073e9SAndroid Build Coastguard Worker return std::tuple<torch::Tensor, torch::Tensor>(self.clone(), other);
1258*da0073e9SAndroid Build Coastguard Worker }
1259*da0073e9SAndroid Build Coastguard Worker
ret_single_non_tensor(const torch::Tensor & self,const torch::Tensor & other)1260*da0073e9SAndroid Build Coastguard Worker int64_t ret_single_non_tensor(
1261*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& self,
1262*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& other) {
1263*da0073e9SAndroid Build Coastguard Worker return 12;
1264*da0073e9SAndroid Build Coastguard Worker }
1265*da0073e9SAndroid Build Coastguard Worker
// Returns `self + other` when `other` is present, otherwise a clone of `self`.
torch::Tensor opt_op(
    const torch::Tensor& self,
    const std::optional<at::Tensor>& other) {
  return other.has_value() ? self + *other : self.clone();
}
1275*da0073e9SAndroid Build Coastguard Worker
// Element-wise sum of the two inputs, returned as a new tensor.
torch::Tensor my_custom_op(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  auto total = self + other;
  return total;
}
1281*da0073e9SAndroid Build Coastguard Worker
ret_tuple_non_tensor(const torch::Tensor & self,const torch::Tensor & other)1282*da0073e9SAndroid Build Coastguard Worker std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor(
1283*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& self,
1284*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& other) {
1285*da0073e9SAndroid Build Coastguard Worker auto a = self - other;
1286*da0073e9SAndroid Build Coastguard Worker auto b = self + other;
1287*da0073e9SAndroid Build Coastguard Worker return std::tuple<torch::Tensor, torch::Tensor, int64_t>(a, b, 12);
1288*da0073e9SAndroid Build Coastguard Worker }
1289*da0073e9SAndroid Build Coastguard Worker
// Returns an alias of `self` (shares its storage).
torch::Tensor view_op(const torch::Tensor& self) {
  auto aliased = self.alias();
  return aliased;
}
1293*da0073e9SAndroid Build Coastguard Worker
// Returns an alias of `self`; `other` is accepted but unused.
torch::Tensor view_op_with_extra_arg(
    const torch::Tensor& self,
    const torch::Tensor& other) {
  auto aliased = self.alias();
  return aliased;
}
1299*da0073e9SAndroid Build Coastguard Worker
ret_tensor_vector_view(const torch::Tensor & self,const torch::Tensor & other)1300*da0073e9SAndroid Build Coastguard Worker std::vector<torch::Tensor> ret_tensor_vector_view(
1301*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& self,
1302*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& other) {
1303*da0073e9SAndroid Build Coastguard Worker return {self.alias(), self.alias()};
1304*da0073e9SAndroid Build Coastguard Worker }
1305*da0073e9SAndroid Build Coastguard Worker
ret_tensor_vector(const torch::Tensor & self,const torch::Tensor & other)1306*da0073e9SAndroid Build Coastguard Worker std::vector<at::Tensor> ret_tensor_vector(
1307*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& self,
1308*da0073e9SAndroid Build Coastguard Worker const torch::Tensor& other) {
1309*da0073e9SAndroid Build Coastguard Worker std::vector<at::Tensor> out;
1310*da0073e9SAndroid Build Coastguard Worker out.push_back(self + other);
1311*da0073e9SAndroid Build Coastguard Worker out.push_back(self - other);
1312*da0073e9SAndroid Build Coastguard Worker return out;
1313*da0073e9SAndroid Build Coastguard Worker }
1314*da0073e9SAndroid Build Coastguard Worker
// Reference CPU kernel for `_test::tensorlist_op`: returns a fresh copy of
// `self` with every tensor of `other` added to it, in list order. `self` itself
// is never modified.
torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) {
  // Own the clone by value. The previous `const auto& res = self.clone()`
  // bound a temporary to a const reference and then mutated it in place,
  // which read as a contradiction and prevented NRVO on the return.
  auto res = self.clone();
  for (const auto& t : other) {
    res.add_(t);
  }
  return res;
}
1322*da0073e9SAndroid Build Coastguard Worker
// Defines the `_test` operator `schema`, installs `fn` as its CPU kernel, and
// routes the Autograd and ADInplaceOrView keys for `name` to the
// "not implemented" fallbacks under test.
//
// The macro intentionally expands to plain declarations (no do/while wrapper):
// the Library handles it creates (m, m_autograd, m_cpu, m_inplaceorview) are
// meant to stay in the enclosing scope for the rest of the test, presumably so
// the registrations they own outlive every call to the operator. Because the
// names are fixed, it can be expanded at most once per scope.
#define REGISTER_TEST_OP(name, schema, fn)                                \
  auto m = MAKE_TORCH_LIBRARY(_test);                                     \
  m.def(schema);                                                          \
  auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd);             \
  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);                       \
  auto m_inplaceorview = MAKE_TORCH_LIBRARY_IMPL(_test, ADInplaceOrView);  \
  m_autograd.impl(                                                        \
      name, c10::DispatchKey::Autograd, autogradNotImplementedFallback());\
  m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn));                  \
  m_inplaceorview.impl(                                                   \
      name,                                                               \
      c10::DispatchKey::ADInplaceOrView,                                  \
      autogradNotImplementedInplaceOrViewFallback());
1336*da0073e9SAndroid Build Coastguard Worker
1337*da0073e9SAndroid Build Coastguard Worker template <typename F>
assertBasicChecks(F op)1338*da0073e9SAndroid Build Coastguard Worker void assertBasicChecks(F op) {
1339*da0073e9SAndroid Build Coastguard Worker auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1340*da0073e9SAndroid Build Coastguard Worker auto b = torch::tensor({1.}, {torch::kFloat32});
1341*da0073e9SAndroid Build Coastguard Worker auto c = torch::tensor({1.}, {torch::kFloat32});
1342*da0073e9SAndroid Build Coastguard Worker
1343*da0073e9SAndroid Build Coastguard Worker // If any inputs require grad,
1344*da0073e9SAndroid Build Coastguard Worker auto out1 = op(a, b);
1345*da0073e9SAndroid Build Coastguard Worker ASSERT_THROWS_WITH(out1.backward(), "is not implemented");
1346*da0073e9SAndroid Build Coastguard Worker
1347*da0073e9SAndroid Build Coastguard Worker // # Should not have grad_fn if none require grad
1348*da0073e9SAndroid Build Coastguard Worker auto out2 = op(b, c);
1349*da0073e9SAndroid Build Coastguard Worker ASSERT_THROWS_WITH(
1350*da0073e9SAndroid Build Coastguard Worker out2.backward(),
1351*da0073e9SAndroid Build Coastguard Worker "element 0 of tensors does not require grad and does not have a grad_fn");
1352*da0073e9SAndroid Build Coastguard Worker
1353*da0073e9SAndroid Build Coastguard Worker // TODO: Forward AD Tests?
1354*da0073e9SAndroid Build Coastguard Worker }
1355*da0073e9SAndroid Build Coastguard Worker
1356*da0073e9SAndroid Build Coastguard Worker } // namespace
1357*da0073e9SAndroid Build Coastguard Worker
// An operator returning only an int: the dispatched value must equal the
// reference kernel's, even when an input requires grad.
TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) {
  REGISTER_TEST_OP(
      "ret_single_non_tensor",
      "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int",
      ret_single_non_tensor);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_single_non_tensor", "");
  auto op = [&](const torch::Tensor& self, const torch::Tensor& other) {
    return callOpUnboxed<int64_t, const torch::Tensor&, const torch::Tensor&>(
        handle, self, other);
  };

  auto grad_input =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto plain_input = torch::tensor({1.}, {torch::kFloat32});

  const auto expected = ret_single_non_tensor(grad_input, plain_input);
  ASSERT_EQ(op(grad_input, plain_input), expected);
}
1375*da0073e9SAndroid Build Coastguard Worker
// Operator that mutates and returns its first argument (`Tensor(a!) self`).
TEST(TestAutogradNotImplementedFallback, InplaceOp) {
  REGISTER_TEST_OP(
      "inplace_op",
      "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)",
      inplace_op);
  auto handle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", "");
  auto op = [&](const torch::Tensor& self, const torch::Tensor& other) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(handle, self, other);
  };

  auto lhs = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto rhs = torch::tensor({1.}, {torch::kFloat32});

  // A leaf that requires grad may not be the mutated argument...
  ASSERT_THROWS_WITH(
      op(lhs, rhs),
      "a leaf Variable that requires grad is being used in an in-place operation");
  // ...but mutating a tensor that does not require grad is accepted.
  op(rhs, lhs);

  // With non-leaf copies the dispatched result must match the reference.
  lhs = lhs.clone();
  rhs = rhs.clone();
  auto dispatched = op(lhs, rhs);
  ASSERT_TRUE(torch::allclose(dispatched, inplace_op(lhs, rhs)));

  // In-place on views.
  auto base =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto view = base.view(-1);
  auto addend = torch::tensor({1.}, {torch::kFloat32});

  // A view created (and mutated) while grad mode is disabled...
  const torch::Tensor nograd_view = [&] {
    c10::NoGradGuard no_grad;
    auto created = base.view(-1);
    op(created, addend);
    return created;
  }();

  // ...may not be mutated in-place once grad mode is back on.
  ASSERT_THROWS_WITH(
      op(nograd_view, addend), "A view was created in no_grad mode");
  // A regular view is mutated in place (same TensorImpl comes back) and is
  // expected to carry an AsStridedBackward grad_fn afterwards.
  ASSERT_EQ(
      op(view, addend).unsafeGetTensorImpl(), view.unsafeGetTensorImpl());
  ASSERT_THAT(
      op(view, addend).grad_fn()->name(),
      ::testing::HasSubstr("AsStridedBackward"));
}
1421*da0073e9SAndroid Build Coastguard Worker
// Both arguments are mutated (`a!`, `b!`): a grad-requiring leaf in either
// position is rejected, and both inputs get their version counters bumped.
TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) {
  REGISTER_TEST_OP(
      "two_arg_inplace_op",
      "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))",
      two_arg_inplace_op);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::two_arg_inplace_op", "");
  auto op = [&](const torch::Tensor& self, const torch::Tensor& other) {
    return callOpUnboxed<
        std::tuple<torch::Tensor, torch::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(handle, self, other);
  };
  auto leaf = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto plain = torch::tensor({1.}, {torch::kFloat32});

  // Both are modified in-place, so argument order does not matter.
  ASSERT_THROWS_WITH(
      op(leaf, plain),
      "a leaf Variable that requires grad is being used in an in-place operation");
  ASSERT_THROWS_WITH(
      op(plain, leaf),
      "a leaf Variable that requires grad is being used in an in-place operation");

  auto make_non_leaf = [] {
    return torch::tensor({1.}, {torch::kFloat32})
        .set_requires_grad(true)
        .clone();
  };
  auto first = make_non_leaf();
  auto second = make_non_leaf();

  const auto first_version = first._version();
  const auto second_version = second._version();
  op(first, second);
  ASSERT_NE(first._version(), first_version);
  ASSERT_NE(second._version(), second_version);
}
1457*da0073e9SAndroid Build Coastguard Worker
// `other` is an optional tensor: the dispatched op must match the reference
// kernel both when it is supplied and when it is omitted.
TEST(TestAutogradNotImplementedFallback, OptOp) {
  REGISTER_TEST_OP(
      "opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op);
  auto handle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", "");
  auto op = [&](const torch::Tensor& self,
                const std::optional<torch::Tensor>& other) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const std::optional<torch::Tensor>&>(handle, self, other);
  };

  auto input = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto extra = torch::tensor({1.}, {torch::kFloat32});

  // Optional argument supplied.
  ASSERT_TRUE(torch::allclose(op(input, extra), opt_op(input, extra)));
  // Optional argument omitted.
  ASSERT_TRUE(torch::allclose(op(input, {}), opt_op(input, {})));
}
1477*da0073e9SAndroid Build Coastguard Worker
// A plain out-of-place operator: only the shared basic checks apply.
TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) {
  REGISTER_TEST_OP(
      "my_custom_op",
      "_test::my_custom_op(Tensor self, Tensor other) -> Tensor",
      my_custom_op);
  auto handle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", "");

  assertBasicChecks([&](const torch::Tensor& self, const torch::Tensor& other) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(handle, self, other);
  });
}
1494*da0073e9SAndroid Build Coastguard Worker
// Mixed (Tensor, Tensor, int) return: the basic checks run on the first
// tensor output.
TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) {
  REGISTER_TEST_OP(
      "ret_tuple_non_tensor",
      "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)",
      ret_tuple_non_tensor);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_tuple_non_tensor", "");
  auto first_output = [&](const torch::Tensor& self,
                          const torch::Tensor& other) {
    return std::get<0>(callOpUnboxed<
                       std::tuple<torch::Tensor, torch::Tensor, int64_t>,
                       const torch::Tensor&,
                       const torch::Tensor&>(handle, self, other));
  };

  assertBasicChecks(first_output);
}
1513*da0073e9SAndroid Build Coastguard Worker
// Single-input view operator: the output must be a view of the input, and
// in-place modification of that view must be refused.
TEST(TestAutogradNotImplementedFallback, ViewOp) {
  REGISTER_TEST_OP(
      "view_op", "_test::view_op(Tensor(a) self) -> Tensor(a)", view_op);
  auto handle =
      c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", "");
  auto op = [&](const torch::Tensor& self) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&>(handle, self);
  };

  // Input without grad.
  auto plain_base = torch::tensor({1.}, {torch::kFloat32});
  auto plain_view = op(plain_base);
  ASSERT_TRUE(plain_view.is_view());
  ASSERT_EQ(
      plain_view._base().unsafeGetTensorImpl(),
      plain_base.unsafeGetTensorImpl());

  // Non-leaf input that requires grad.
  auto grad_base =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
  auto grad_view = op(grad_base);
  ASSERT_TRUE(grad_view.is_view());
  ASSERT_EQ(
      grad_view._base().unsafeGetTensorImpl(), grad_base.unsafeGetTensorImpl());

  // Test inplace on view
  auto addend = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);

  // raise on rebase_history when it refreshes grad_fn
  ASSERT_THROWS_WITH(
      grad_view.add_(addend),
      "which does not have a derivative implemented is forbidden");
  // base should not be aware of the views, so this is still okay
  grad_base.add_(addend);
  ASSERT_THROWS_WITH(
      grad_view.grad_fn(),
      "which does not have a derivative implemented is forbidden");
}
1545*da0073e9SAndroid Build Coastguard Worker
// View operator with an extra non-aliased argument: the output must be a view
// of `self` only.
TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) {
  REGISTER_TEST_OP(
      "view_op_with_extra_arg",
      "_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)",
      view_op_with_extra_arg);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::view_op_with_extra_arg", "");
  auto op = [&](const torch::Tensor& self, const torch::Tensor& other) {
    return callOpUnboxed<
        torch::Tensor,
        const torch::Tensor&,
        const torch::Tensor&>(handle, self, other);
  };
  assertBasicChecks(op);

  auto self_arg = torch::tensor({1.}, {torch::kFloat32});
  auto other_arg = torch::tensor({2.}, {torch::kFloat32});
  auto result = op(self_arg, other_arg);
  ASSERT_TRUE(result.is_view());
  ASSERT_EQ(
      result._base().unsafeGetTensorImpl(), self_arg.unsafeGetTensorImpl());
}
1566*da0073e9SAndroid Build Coastguard Worker
// `Tensor[](a)` return: both returned tensors must be views of `self`.
TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) {
  REGISTER_TEST_OP(
      "ret_tensor_vector_view",
      "_test::ret_tensor_vector_view(Tensor(a) self, Tensor other) -> Tensor[](a)",
      ret_tensor_vector_view);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_tensor_vector_view", "");
  auto op = [&](const torch::Tensor& self, const torch::Tensor& other) {
    return callOpUnboxed<
        std::vector<at::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(handle, self, other);
  };

  auto source = torch::tensor({1.}, {torch::kFloat32});
  auto other_arg = torch::tensor({1.}, {torch::kFloat32});
  const auto outputs = op(source, other_arg);
  for (const std::size_t i : {0, 1}) {
    ASSERT_TRUE(outputs[i].is_view());
    ASSERT_EQ(
        outputs[i]._base().unsafeGetTensorImpl(),
        source.unsafeGetTensorImpl());
  }
}
1588*da0073e9SAndroid Build Coastguard Worker
// Two outputs aliasing two different inputs: the fallback must reject the
// schema when it is called.
TEST(TestAutogradNotImplementedFallback, DoubleViewOP) {
  REGISTER_TEST_OP(
      "two_pairs_of_view_op",
      "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))",
      two_pairs_of_view_op);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::two_pairs_of_view_op", "");
  auto op = [&](const torch::Tensor& self, const torch::Tensor& other) {
    return callOpUnboxed<
        std::tuple<torch::Tensor, torch::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(handle, self, other);
  };

  auto grad_input =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto plain_input = torch::tensor({1.}, {torch::kFloat32});
  ASSERT_THROWS_WITH(
      op(grad_input, plain_input),
      "Expected only a single output in the operator schema to have a non-write alias annotation");
}
1608*da0073e9SAndroid Build Coastguard Worker
// The aliasing output is not the first one: the fallback must reject it when
// the operator is called.
TEST(TestAutogradNotImplementedFallback, NonFirstViewOP) {
  REGISTER_TEST_OP(
      "non_first_view_op",
      "_test::non_first_view_op(Tensor self, Tensor(b) other) -> (Tensor, Tensor(b))",
      non_first_view_op);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::non_first_view_op", "");
  auto op = [&](const torch::Tensor& self, const torch::Tensor& other) {
    return callOpUnboxed<
        std::tuple<torch::Tensor, torch::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(handle, self, other);
  };

  auto grad_input =
      torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  auto plain_input = torch::tensor({1.}, {torch::kFloat32});
  ASSERT_THROWS_WITH(
      op(grad_input, plain_input),
      "can only create view relationships between the first");
}
1627*da0073e9SAndroid Build Coastguard Worker
// `Tensor[]` return without aliasing: the basic checks run on the first
// element of the returned list.
TEST(TestAutogradNotImplementedFallback, RetTensorVector) {
  REGISTER_TEST_OP(
      "ret_tensor_vector",
      "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]",
      ret_tensor_vector);
  auto handle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::ret_tensor_vector", "");
  auto first_output = [&](const torch::Tensor& self,
                          const torch::Tensor& other) {
    auto outputs = callOpUnboxed<
        std::vector<at::Tensor>,
        const torch::Tensor&,
        const torch::Tensor&>(handle, self, other);
    return outputs.front();
  };

  assertBasicChecks(first_output);
}
1643*da0073e9SAndroid Build Coastguard Worker
// `Tensor[]` input: asking for the gradient of a list element that does not
// require grad fails with the usual error, asking for one that does hits the
// not-implemented backward, and the forward result matches the reference.
TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
  REGISTER_TEST_OP(
      "tensorlist_op",
      "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor",
      tensorlist_op);
  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
      "_test::tensorlist_op", "");
  // Take `self` by const reference like every other wrapper in this file;
  // taking it by value only added a refcount bump per call.
  auto op = [&](const torch::Tensor& _1, at::TensorList _2) {
    return callOpUnboxed<torch::Tensor, const torch::Tensor&, at::TensorList>(
        opHandle, _1, _2);
  };

  auto a = torch::tensor({1.}, {torch::kFloat32});
  auto b = torch::tensor({1.}, {torch::kFloat32});
  auto c = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
  std::vector<torch::Tensor> vec = {b, c};
  auto out = op(a, vec);

  ASSERT_THROWS_WITH(
      torch::autograd::grad({out}, {vec[0]}),
      "element 0 of the input tensors does not require grad");
  ASSERT_THROWS_WITH(
      torch::autograd::grad({out}, {vec[1]}), "is not implemented");

  // torch::allclose for consistency with the rest of the file.
  ASSERT_TRUE(torch::allclose(op(a, vec), tensorlist_op(a, vec)));
}
1670*da0073e9SAndroid Build Coastguard Worker
1671*da0073e9SAndroid Build Coastguard Worker // TODO add these tests if needed
1672*da0073e9SAndroid Build Coastguard Worker // test_once_differentiable
1673*da0073e9SAndroid Build Coastguard Worker // test_sparse_backward
1674*da0073e9SAndroid Build Coastguard Worker // test_save_output_nr
1675*da0073e9SAndroid Build Coastguard Worker // test_free_deep_graph_pyfunction
1676*da0073e9SAndroid Build Coastguard Worker // test_naughty_anomaly_access
// test_naughty_autograd_function_stashing_ctx
1678*da0073e9SAndroid Build Coastguard Worker // test_custom_autograd_repeated_grad_grad
1679*da0073e9SAndroid Build Coastguard Worker // test_return_leaf
1680*da0073e9SAndroid Build Coastguard Worker // test_anomaly_detect_nan
1681*da0073e9SAndroid Build Coastguard Worker // test_no_grad_copy
1682