1 #include <ATen/core/boxing/impl/test_helpers.h>
2 #include <gtest/gtest.h>
3 
4 #include <ATen/core/op_registration/op_registration.h>
5 #include <torch/torch.h>
6 
7 #include <torch/csrc/autograd/FunctionsManual.h>
8 #include <torch/csrc/autograd/functions/basic_ops.h>
9 
10 #include <test/cpp/api/support.h>
11 
12 using namespace torch::autograd;
13 using namespace torch::test;
14 
15 #define ASSERT_VARIABLE_EQ(a, b) ASSERT_TRUE(torch::allclose((a), (b)))
16 #define EXPECT_VARIABLE_EQ(a, b) EXPECT_TRUE(torch::allclose((a), (b)))
17 
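// Recursively renders the autograd graph reachable from `node` as a nested string of node names.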
18 std::string graph_desc(std::shared_ptr<Node> node) {
19   if (!node) {
20     return "None";
21   }
22   auto result = node->name() + "(";
23   auto next_edges = node->next_edges();
24   for (auto& edge : next_edges) {
25     result += graph_desc(edge.function);
26   }
27   return result + ")";
28 }
29 
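// Computes x + 2*y + x*y; the analytic gradients are d/dx = 1 + y and d/dy = 2 + x, which the tests below assert.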
30 Variable simple_fn(const Variable& x, const Variable& y) {
31   return x + 2 * y + x * y;
32 }
33 
34 TEST(AutogradAPITests, RegisterHookVoidReturnAcceptsUndefinedTensor) {
35   auto x = at::zeros({}, at::kCPU);
36   x.requires_grad_();
37   x.register_hook([](at::TensorBase x) { return; });
38   auto y = torch::autograd::UndefinedGrad().apply({x});
39   y[0].backward();
40 }
41 
42 TEST(AutogradAPITests, RegisterHookTensorReturnAcceptsUndefinedTensor) {
43   auto x = at::zeros({}, at::kCPU);
44   x.requires_grad_();
45   x.register_hook([](at::Tensor x) -> at::Tensor { return x; });
46   auto y = torch::autograd::UndefinedGrad().apply({x});
47   y[0].backward();
48 }
49 
50 TEST(AutogradAPITests, BackwardSimpleTest) {
51   Variable x = torch::randn({2, 2}, torch::requires_grad());
52   Variable y = torch::randn({2, 2}, torch::requires_grad());
53   auto res = simple_fn(x, y);
54   backward({res.sum()}, {});
55 
56   ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({2, 2}));
57   ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2}) * 2);
58 }
59 
60 TEST(AutogradAPITests, BackwardTest) {
61   Variable x = torch::randn({2, 2}, torch::requires_grad());
62   Variable y = torch::randn({2, 2}, torch::requires_grad());
63   auto res = simple_fn(x, y);
64   backward({res}, {torch::ones({2, 2})}, {}, true);
65 
66   backward({res}, {torch::ones({2, 2})});
67 
68   ASSERT_VARIABLE_EQ(x.grad(), 2 * (y + torch::ones({2, 2})));
69   ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2}) * 2));
70 }
71 
72 TEST(AutogradAPITests, GradSimpleTest) {
73   // basic grad
74   Variable x = torch::randn({2, 2}, torch::requires_grad());
75   Variable y = torch::randn({2, 2}, torch::requires_grad());
76   auto res = simple_fn(x, y);
77   auto grad_res = grad({res}, {x, y}, {torch::ones({2, 2})});
78 
79   ASSERT_VARIABLE_EQ(grad_res[0], y + torch::ones({2, 2}));
80   ASSERT_VARIABLE_EQ(grad_res[1], x + torch::ones({2, 2}) * 2);
81 }
82 
83 TEST(AutogradAPITests, GradTest) {
84   Variable x = torch::randn({2, 2}, torch::requires_grad());
85   Variable y = torch::randn({2, 2}, torch::requires_grad());
86   auto res = simple_fn(x, y);
87   res.backward(torch::ones({2, 2}), false, true);
88 
89   Variable x_grad = y + torch::ones({2, 2});
90   Variable y_grad = x + torch::ones({2, 2}) * 2;
91   ASSERT_VARIABLE_EQ(x.grad(), x_grad);
92   ASSERT_VARIABLE_EQ(y.grad(), y_grad);
93 
94   Variable grad_sum = 2 * x.grad() + y.grad();
95   auto x_hv = grad({grad_sum}, {x}, {torch::ones({2, 2})}, {}, true);
96 
97   ASSERT_VARIABLE_EQ(x_hv[0], torch::ones({2, 2}));
98   ASSERT_VARIABLE_EQ(x.grad(), x_grad);
99   ASSERT_VARIABLE_EQ(y.grad(), y_grad);
100 }
101 
102 TEST(AutogradAPITests, GradNonLeafTest) {
103   Variable x_init = torch::randn({2, 2}, torch::requires_grad());
104   Variable x = x_init;
105   Variable y = torch::randn({2, 2}, torch::requires_grad());
106   Variable grad_output = torch::ones({2, 2});
107 
108   for (int i = 0; i < 5; ++i) {
109     auto res = simple_fn(x, y);
110     auto input_grads = grad({res}, {x}, {grad_output}, {}, true);
111 
112     Variable grad_x_expected = y + torch::ones({2, 2});
113     ASSERT_VARIABLE_EQ(input_grads[0], grad_x_expected);
114     ASSERT_FALSE(x.grad().defined());
115     ASSERT_FALSE(y.grad().defined());
116     x = x + 0.05 * input_grads[0];
117   }
118 
119   float val_init = simple_fn(x_init, y).sum().item().toFloat();
120   float val_final = simple_fn(x, y).sum().item().toFloat();
121   ASSERT_TRUE(val_final > val_init);
122 
123   x.backward(grad_output, false, true);
124   ASSERT_TRUE(x_init.grad().defined());
125   ASSERT_TRUE(y.grad().defined());
126 }
127 
128 TEST(AutogradAPITests, GradUnreachableTest) {
129   Variable x = torch::ones({1}, torch::requires_grad());
130   Variable y = torch::ones({1}, torch::requires_grad());
131 
132   Variable z = x * 2;
133   Variable w = y * 2;
134 
135   auto grad_res = grad({x * 2}, {x, y}, {}, {}, false, true);
136   ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
137   ASSERT_FALSE(grad_res[1].defined());
138 
139   // This is slightly different from the case above, because z doesn't even
140   // have a grad accumulator allocated.
141   z = torch::ones({1}, torch::requires_grad());
142   grad_res = grad({x * 2}, {x, z}, {}, {}, false, true);
143 
144   ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
145   ASSERT_FALSE(grad_res[1].defined());
146 
147   // allow_unused=false, but the computed grads contain an undefined tensor; should throw
148   ASSERT_THROWS_WITH(
149       grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True");
150 }
151 
152 TEST(CustomAutogradTest, GradUnreachableDiscoveryTest) {
153   // Test that certain nodes are not erroneously executed when an input
154   // is unreachable. See #39784
155   struct MyFunction : public Function<MyFunction> {
156     static Variable forward(AutogradContext* ctx, Variable var) {
157       return var;
158     }
159 
160     static variable_list backward(
161         AutogradContext* ctx,
162         variable_list grad_output) {
163       ADD_FAILURE() << "This node should not be executed!";
164       return grad_output;
165     }
166   };
167 
168   auto x = torch::randn(1, torch::requires_grad());
169   auto x1 = torch::randn(1);
170   auto x2 = MyFunction::apply(x + x1);
171 
172   auto y = torch::randn(1, torch::requires_grad());
173   auto grad_res = torch::autograd::grad({x2}, {y}, {}, {}, false, true);
174   ASSERT_FALSE(grad_res[0].defined());
175 }
176 
177 TEST(AutogradAPITests, EmptyInput) {
178   Variable x = torch::ones({1}, torch::requires_grad());
179   ASSERT_THROWS_WITH(
180       grad({x * 2}, /*inputs=*/{}, {x}), "grad requires non-empty inputs.");
181 }
182 
183 TEST(AutogradAPITests, RetainGrad) {
184   auto input = torch::rand({1, 3}, torch::requires_grad());
185   auto h1 = input * 3;
186   auto out = (h1 * h1).sum();
187 
188   {
189     // Warning when grad is accessed for non-leaf tensor
190     WarningCapture warnings;
191     ASSERT_FALSE(h1.grad().defined());
192     ASSERT_TRUE(warnings.str().find("is not a leaf") != std::string::npos);
193   }
194   // It should be possible to call retain_grad() multiple times
195   h1.retain_grad();
196   h1.retain_grad();
197   {
198     // If retain_grad is true for a non-leaf tensor,
199     // there should not be any warning when grad is accessed
200     WarningCapture warnings;
201     ASSERT_FALSE(h1.grad().defined());
202     ASSERT_FALSE(warnings.str().find("is not a leaf") != std::string::npos);
203   }
204 
205   // Gradient should be accumulated
206   // NOLINTNEXTLINE(bugprone-argument-comment)
207   out.backward({}, /*retain_graph=*/true);
208   ASSERT_VARIABLE_EQ(h1 * 2, h1.grad());
209   // NOLINTNEXTLINE(bugprone-argument-comment)
210   out.backward({}, /*retain_graph=*/true);
211   ASSERT_VARIABLE_EQ(h1 * 4, h1.grad());
212 
213   {
214     torch::NoGradGuard no_grad;
215     input.grad().zero_();
216   }
217   // It should be a no-op for leaves
218   input.retain_grad();
219   input.retain_grad();
220   out.backward();
221   ASSERT_VARIABLE_EQ(input * 18, input.grad());
222 }
223 
224 TEST(AutogradAPITests, AnomalyMode) {
225   // Anomaly mode should emit the forward-pass traceback as a warning and then throw an error
226   torch::autograd::DetectAnomalyGuard detect_anomaly;
227   {
228     WarningCapture warnings;
229     auto x = torch::tensor({5.0}, torch::requires_grad());
230     auto y = x * x;
231     auto z = y * y;
232     y += 1;
233     ASSERT_THROWS_WITH(z.backward(), "inplace");
234     ASSERT_TRUE(
235         warnings.str().find("Traceback of forward") != std::string::npos);
236   }
237   auto double_backward_produce_nan = [](bool should_throw) {
238     auto x = torch::tensor({0.0}, torch::requires_grad());
239     auto y = x.pow(1.5);
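    // The second derivative of x^1.5 involves x^-0.5, which is Inf at x = 0; multiplied by the
    // zero grad_output below, the double backward therefore produces NaN.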
240     auto gr =
241         // NOLINTNEXTLINE(bugprone-argument-comment)
242         grad({y}, {x}, {}, /*retain_graph=*/true, /*create_graph=*/true);
243     if (should_throw) {
244       WarningCapture warnings;
245       ASSERT_THROWS_WITH(
246           grad({gr[0]}, {x}, {torch::tensor({0.0})}), "returned nan");
247       auto msgs = warnings.messages();
248       ASSERT_EQ(msgs.size(), 2);
249       ASSERT_TRUE(
250           msgs[0].find("Traceback of forward call that caused the error") !=
251           std::string::npos);
252       ASSERT_TRUE(
253           msgs[1].find(
254               "Traceback of forward call that induced the previous calculation") !=
255           std::string::npos);
256     } else {
257       grad({gr[0]}, {x}, {torch::tensor({0.0})});
258     }
259   };
260 
261   double_backward_produce_nan(true);
262   {
263     torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/false);
264     double_backward_produce_nan(false);
265     {
266       torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/true);
267       double_backward_produce_nan(true);
268     }
269   }
270   double_backward_produce_nan(true);
271 }
272 
273 TEST(CustomAutogradTest, CustomFunctionReturnInputAsIsAndSavesIt) {
274   struct MyFunction : public Function<MyFunction> {
275     static Variable forward(
276         AutogradContext* ctx,
277         Variable var1,
278         Variable var2) {
279       ctx->save_for_backward({var1, var2});
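      // Note: the comma operator evaluates var1 * var2, discards the result, and returns var1 unchanged.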
280       return var1 * var2, var1;
281     }
282 
283     static variable_list backward(
284         AutogradContext* ctx,
285         variable_list grad_output) {
286       return {};
287     }
288   };
289 
290   Variable x = torch::randn({5, 5}, torch::requires_grad());
291   Variable y = torch::randn({5, 5}, torch::requires_grad());
292   MyFunction::apply(x, y);
293 }
294 
295 TEST(CustomAutogradTest, CustomFunction) {
296   struct MyFunction : public Function<MyFunction> {
297     static Variable forward(
298         AutogradContext* ctx,
299         Variable var1,
300         int mul,
301         Variable var2) {
302       ctx->saved_data["mul"] = mul;
303       ctx->save_for_backward({var1, var2});
304       return var1 + mul * var2 + var1 * var2;
305     }
306 
307     static variable_list backward(
308         AutogradContext* ctx,
309         variable_list grad_output) {
310       int mul = ctx->saved_data["mul"].toInt();
311       auto saved = ctx->get_saved_variables();
312       auto var1 = saved[0];
313       auto var2 = saved[1];
314       variable_list output = {
315           grad_output[0] + grad_output[0] * var2,
316           Variable(),
317           grad_output[0] * mul + grad_output[0] * var1};
318       return output;
319     }
320   };
321 
322   Variable x = torch::randn({5, 5}, torch::requires_grad());
323   Variable y = torch::randn({5, 5}, torch::requires_grad());
324   auto res = MyFunction::apply(x, 2, y);
325   auto go = torch::ones({}, torch::requires_grad());
326   res.sum().backward(go, false, true);
327 
328   ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
329   ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2);
330 }
331 
332 TEST(CustomAutogradTest, CustomFunctionWithTensorList) {
333   struct MyFunction : public Function<MyFunction> {
334     static Variable forward(AutogradContext* ctx, at::TensorList tensors) {
335       torch::autograd::variable_list vars;
336       for (const at::Tensor& tensor : tensors) {
337         vars.push_back(tensor);
338       }
339       ctx->save_for_backward(vars);
340       return tensors[0] + tensors[1] + tensors[0] * tensors[1];
341     }
342 
343     static variable_list backward(
344         AutogradContext* ctx,
345         variable_list grad_output) {
346       auto saved = ctx->get_saved_variables();
347       auto var1 = saved[0];
348       auto var2 = saved[1];
349       variable_list output = {
350           grad_output[0] + grad_output[0] * var2,
351           grad_output[0] + grad_output[0] * var1};
352       return output;
353     }
354   };
355 
356   at::Tensor x = torch::randn({5, 5}, torch::requires_grad());
357   at::Tensor y = torch::randn({5, 5}, torch::requires_grad());
358   torch::autograd::variable_list variables = {x, y};
359   at::TensorList tensors = variables;
360   auto res = MyFunction::apply(tensors);
361   auto go = torch::ones({}, torch::requires_grad());
362   res.sum().backward(go, false, true);
363 
364   ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
365   ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}));
366 }
367 
368 TEST(CustomAutogradTest, GraphTaskTrimEdges) {
369   struct MyFunction : public Function<MyFunction> {
370     static Variable forward(
371         AutogradContext* ctx,
372         Variable var1,
373         Variable var2,
374         int mul,
375         bool needs_input1_grad,
376         bool needs_input2_grad) {
377       // Set up which inputs are expected to need gradients
378       ctx->saved_data["needs_input1_grad"] = needs_input1_grad;
379       ctx->saved_data["needs_input2_grad"] = needs_input2_grad;
380 
381       ctx->saved_data["mul"] = mul;
382       ctx->save_for_backward({var1, var2});
383       return var1 + mul * var2 + var1 * var2;
384     }
385 
386     static variable_list backward(
387         AutogradContext* ctx,
388         variable_list grad_output) {
389       // Test that the `needs_input_grad` method works correctly.
390       // This has to be checked from within the backward function.
391       auto needs_input1_grad = ctx->saved_data["needs_input1_grad"].toBool();
392       auto needs_input2_grad = ctx->saved_data["needs_input2_grad"].toBool();
393       IndexRange var1_idx = {0, 1};
394       IndexRange var2_idx = {1, 2};
395       EXPECT_EQ(ctx->needs_input_grad(0), needs_input1_grad);
396       EXPECT_EQ(ctx->needs_input_grad(1), needs_input2_grad);
397       EXPECT_EQ(ctx->needs_input_grad({var1_idx}), needs_input1_grad);
398       EXPECT_EQ(ctx->needs_input_grad({var2_idx}), needs_input2_grad);
399       EXPECT_EQ(
400           ctx->needs_input_grad({var1_idx, var2_idx}),
401           needs_input1_grad || needs_input2_grad);
402 
403       // calculate gradients
404       int mul = ctx->saved_data["mul"].toInt();
405       auto saved = ctx->get_saved_variables();
406       auto var1 = saved[0];
407       auto var2 = saved[1];
408 
409       Variable grad_var1, grad_var2;
410       if (ctx->needs_input_grad(0)) {
411         grad_var1 = grad_output[0] + grad_output[0] * var2;
412       }
413       if (ctx->needs_input_grad(1)) {
414         grad_var2 = grad_output[0] * mul + grad_output[0] * var1;
415       }
416       variable_list output = {
417           grad_var1,
418           grad_var2,
419           Variable(),
420           Variable(),
421           Variable(),
422       };
423       return output;
424     }
425   };
426 
427   Variable x = torch::randn({5, 5}, torch::requires_grad());
428   Variable y = torch::randn({5, 5}, torch::requires_grad());
429   auto go = torch::ones_like(x);
430   Variable out;
431 
432   // grad_x
433   out = MyFunction::apply(
434       x,
435       y,
436       2,
437       /* needs_input1_grad= */ true,
438       /* needs_input2_grad= */ false);
439   auto grad_x = torch::autograd::grad({out}, {x}, {go})[0];
440   ASSERT_VARIABLE_EQ(grad_x, y + torch::ones({5, 5}));
441 
442   // grad_y
443   out = MyFunction::apply(
444       x,
445       y,
446       2,
447       /* needs_input1_grad= */ false,
448       /* needs_input2_grad= */ true);
449   auto grad_y = torch::autograd::grad({out}, {y}, {go})[0];
450   ASSERT_VARIABLE_EQ(grad_y, x + torch::ones({5, 5}) * 2);
451 
452   // grad_x and grad_y
453   out = MyFunction::apply(
454       x,
455       y,
456       2,
457       /* needs_input1_grad= */ true,
458       /* needs_input2_grad= */ true);
459   auto grads = torch::autograd::grad({out}, {x, y}, {go});
460   ASSERT_VARIABLE_EQ(grads[0], y + torch::ones({5, 5}));
461   ASSERT_VARIABLE_EQ(grads[1], x + torch::ones({5, 5}) * 2);
462 }
463 
464 TEST(CustomAutogradTest, FunctionReturnsInput) {
465   struct MyFunction : public Function<MyFunction> {
466     static Variable forward(AutogradContext* ctx, Variable var1) {
467       return var1;
468     }
469 
470     static variable_list backward(
471         AutogradContext* ctx,
472         variable_list grad_output) {
473       return {grad_output[0] * 2};
474     }
475   };
476 
477   Variable x(torch::ones(1, torch::requires_grad()));
478   MyFunction::apply(x).backward(torch::ones(1), true, true);
479   ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2.));
480 }
481 
482 TEST(CustomAutogradTest, FunctionReturnsUndefined) {
483   struct MyFunction : public Function<MyFunction> {
484     static Variable forward(AutogradContext* ctx, Variable var) {
485       return var * 2;
486     }
487 
488     static variable_list backward(
489         AutogradContext* ctx,
490         variable_list grad_output) {
491       at::Tensor undefined_tensor;
492       return {undefined_tensor};
493     }
494   };
495 
496   auto x = torch::ones(1, torch::requires_grad());
497 
498   MyFunction::apply(x).backward();
499   ASSERT_FALSE(x.grad().defined());
500 
501   MyFunction::apply(x.pow(2)).backward();
502   ASSERT_FALSE(x.grad().defined());
503 
504   MyFunction::apply(x).sum().backward();
505   ASSERT_FALSE(x.grad().defined());
506 
507   ASSERT_FALSE(torch::autograd::grad(
508                    {MyFunction::apply(x)}, {x}, {}, false, false, true)[0]
509                    .defined());
510 }
511 
512 TEST(CustomAutogradTest, MaterializeGrads) {
513   struct MyFunction : public Function<MyFunction> {
514     static Variable forward(AutogradContext* ctx, Variable var) {
515       return var;
516     }
517 
518     static variable_list backward(
519         AutogradContext* ctx,
520         variable_list grad_output) {
521       EXPECT_VARIABLE_EQ(grad_output[0], torch::zeros(1));
522       return grad_output;
523     }
524   };
525 
526   auto x = torch::ones(1, torch::requires_grad());
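  // UndefinedGrad feeds an undefined gradient into backward; with materialize_grads enabled
  // (the default), MyFunction::backward receives it as zeros.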
527   UndefinedGrad().apply({MyFunction::apply(x)})[0].backward();
528 }
529 
530 TEST(CustomAutogradTest, DontMaterializeGrads) {
531   struct MyFunction : public Function<MyFunction> {
532     static Variable forward(AutogradContext* ctx, Variable var) {
533       ctx->set_materialize_grads(false);
534       return var;
535     }
536 
537     static variable_list backward(
538         AutogradContext* ctx,
539         variable_list grad_output) {
540       EXPECT_FALSE(grad_output[0].defined());
541       return grad_output;
542     }
543   };
544 
545   auto x = torch::ones(1, torch::requires_grad());
546   UndefinedGrad().apply({MyFunction::apply(x)})[0].backward();
547 }
548 
549 TEST(CustomAutogradTest, NoGradCustomFunction) {
550   // Custom Function should respect grad mode
551   struct MyOp : public Function<MyOp> {
552     static Variable forward(AutogradContext* ctx, Variable x) {
553       return x + 1;
554     }
555 
556     static variable_list backward(AutogradContext* ctx, variable_list dy) {
557       return dy;
558     }
559   };
560 
561   auto x = torch::ones({5, 5}, torch::requires_grad());
562   {
563     at::NoGradGuard no_grad;
564     auto y = MyOp::apply(x);
565     ASSERT_FALSE(y.requires_grad());
566   }
567 }
568 
569 TEST(CustomAutogradTest, MarkDirty) {
570   struct MyFunction : public Function<MyFunction> {
571     static Variable forward(AutogradContext* ctx, Variable v) {
572       // Change the value inplace
573       auto v_data = v.data_ptr<float>();
574       v_data[0] = 2;
575       ctx->mark_dirty({v});
576       return v;
577     }
578 
579     static variable_list backward(
580         AutogradContext* ctx,
581         variable_list grad_output) {
582       return {(grad_output[0] * 2.0)};
583     }
584   };
585 
586   // Clone here because modifying leaves in-place is not allowed
587   auto x = torch::randn({5, 5}, torch::requires_grad()).clone();
588   auto version_before = x._version();
589   auto out = MyFunction::apply(x);
590   auto version_after = x._version();
591   ASSERT_TRUE(version_after >= (version_before + 1));
592   out.sum().backward();
593 }
594 
595 TEST(CustomAutogradTest, MarkNonDifferentiable) {
596   struct MyFunction : public Function<MyFunction> {
597     static Variable forward(AutogradContext* ctx, Variable v) {
598       Variable output = v > 0;
599       ctx->mark_non_differentiable({output});
600       return output;
601     }
602 
603     static variable_list backward(
604         AutogradContext* ctx,
605         variable_list grad_output) {
606       return {(grad_output[0] * 0.0)};
607     }
608   };
609 
610   auto x = torch::randn({5, 5}, torch::requires_grad());
611   auto mask = MyFunction::apply(x);
612   ASSERT_FALSE(mask.requires_grad());
613   auto y = x.masked_fill(mask, 0);
614   y.sum().backward();
615 }
616 
617 TEST(CustomAutogradTest, MarkNonDifferentiableMixed) {
618   struct MyFunction : public Function<MyFunction> {
619     static variable_list forward(AutogradContext* ctx, Variable input) {
620       Variable a = input + 1;
621       Variable b = input + 2;
622       ctx->mark_non_differentiable({a});
623       return {a, b};
624     }
625 
626     static variable_list backward(
627         AutogradContext* ctx,
628         variable_list grad_output) {
629       const Variable &grad_a = grad_output[0], &grad_b = grad_output[1];
630       EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5, 5}));
631       EXPECT_VARIABLE_EQ(grad_b, torch::ones({5, 5}));
632       return {grad_b};
633     }
634   };
635 
636   auto x = torch::randn({5, 5}, torch::requires_grad());
637   auto out = MyFunction::apply(x);
638 
639   ASSERT_FALSE(out[0].requires_grad());
640   ASSERT_TRUE(out[1].requires_grad());
641   out[1].sum().backward();
642   ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5}));
643 }
644 
645 TEST(CustomAutogradTest, MarkNonDifferentiableNone) {
646   struct MyFunction : public Function<MyFunction> {
647     static Variable forward(AutogradContext* ctx, Variable input) {
648       auto output = input.clone();
649       ctx->mark_non_differentiable({output});
650       return output;
651     }
652 
653     static variable_list backward(
654         AutogradContext* ctx,
655         variable_list grad_outputs) {
656       return {};
657     }
658   };
659 
660   auto x = torch::randn({5, 5}, torch::requires_grad());
661   auto r = MyFunction::apply(x * x);
662   (r * x).sum().backward();
663 }
664 
665 TEST(CustomAutogradTest, ReturnLeafInplace) {
666   struct Inplace : public Function<Inplace> {
667     static variable_list forward(AutogradContext* ctx, Variable a, Variable b) {
668       ctx->mark_dirty({a});
669       return {a.add_(b), b + 2};
670     }
671 
672     static variable_list backward(
673         AutogradContext* ctx,
674         variable_list grad_output) {
675       return {grad_output[0], grad_output[0] + grad_output[1]};
676     }
677   };
678 
679   Variable x = torch::randn({5, 5});
680   Variable y = torch::randn({5, 5}, torch::requires_grad());
681 
682   auto out = Inplace::apply(x, y);
683   auto& q = out[0];
684   ASSERT_TRUE(torch::equal(q, x));
685   ASSERT_TRUE(q.requires_grad());
686   q.sum().backward();
687   ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5, 5}));
688 }
689 
690 TEST(CustomAutogradTest, ReturnDuplicateInplace) {
691   struct DoubleInplace : public Function<DoubleInplace> {
692     static variable_list forward(AutogradContext* ctx, Variable x) {
693       x.mul_(2);
694       ctx->mark_dirty({x});
695       return {x, x};
696     }
697 
698     static variable_list backward(
699         AutogradContext* ctsx,
700         variable_list grad_outputs) {
701       return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
702     }
703   };
704 
705   auto x = torch::randn({5, 5}, torch::requires_grad());
706 
707   ASSERT_THROWS_WITH(
708       DoubleInplace::apply(x), "leaf Variable that requires grad");
709   // TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one
710   // output");
711 
712   auto out = DoubleInplace::apply(x.clone());
713   ASSERT_TRUE(torch::equal(out[0], out[1]));
714 }
715 
716 TEST(CustomAutogradTest, ReturnDuplicate) {
717   struct DoubleDuplicate : public Function<DoubleDuplicate> {
718     static variable_list forward(AutogradContext* ctx, Variable x) {
719       auto output = x * 2;
720       return {output, output};
721     }
722 
723     static variable_list backward(
724         AutogradContext* ctx,
725         variable_list grad_outputs) {
726       return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
727     }
728   };
729 
730   auto x = torch::randn({5, 5}, torch::requires_grad());
731   auto out = DoubleDuplicate::apply(x);
732   ASSERT_TRUE(torch::equal(out[0], out[1]));
733 }
734 
735 TEST(CustomAutogradTest, SaveEmptyForBackward) {
736   struct MyFunction : public Function<MyFunction> {
737     static Variable forward(AutogradContext* ctx, Variable input) {
738       ctx->save_for_backward({Variable(), input, Variable()});
739       return input * input;
740     }
741 
742     static variable_list backward(
743         AutogradContext* ctx,
744         variable_list grad_output) {
745       auto saved = ctx->get_saved_variables();
746       EXPECT_FALSE(saved[0].defined());
747       EXPECT_FALSE(saved[2].defined());
748       return {saved[1] * 2 * grad_output[0]};
749     }
750   };
751 
752   Variable x = torch::randn({5, 5}, torch::requires_grad());
753   auto y = MyFunction::apply(x);
754   y.sum().backward();
755   ASSERT_VARIABLE_EQ(x.grad(), 2 * x);
756 }
757 
758 TEST(CustomAutogradTest, InvalidGradients) {
759   struct MyFunction : public Function<MyFunction> {
760     static Variable forward(AutogradContext* ctx, Variable x) {
761       return x * 2;
762     }
763 
764     static variable_list backward(
765         AutogradContext* ctsx,
766         variable_list grad_outputs) {
767       return {
768           torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))};
769     }
770   };
771 
772   auto input1 =
773       torch::randn({5, 5}, torch::dtype(torch::kFloat).requires_grad(true));
774   ASSERT_THROWS_WITH(
775       MyFunction::apply(input1).sum().backward(), "expected shape");
776   auto input2 =
777       torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true));
778 }
779 
780 TEST(CustomAutogradTest, NoGradInput) {
781   struct MyFunction : public Function<MyFunction> {
782     static Variable forward(AutogradContext*, Variable x) {
783       return x;
784     }
785 
786     static variable_list backward(
787         AutogradContext*,
788         variable_list grad_outputs) {
789       return grad_outputs;
790     }
791   };
792 
793   Variable x = torch::randn({5, 5}, torch::requires_grad());
794   Variable y;
795   {
796     at::NoGradGuard no_grad;
797     y = MyFunction::apply(x);
798   }
799 
800   ASSERT_TRUE(x.requires_grad());
801   ASSERT_FALSE(y.grad_fn());
802 }
803 
804 TEST(CustomAutogradTest, TooManyGrads) {
805   struct MyFunction : public Function<MyFunction> {
806     static Variable forward(AutogradContext*, Variable input) {
807       return input;
808     }
809 
810     static variable_list backward(AutogradContext*, variable_list grad_output) {
811       grad_output.insert(grad_output.end(), {Variable(), Variable()});
812       return grad_output;
813     }
814   };
815 }
816 
817 TEST(CustomAutogradTest, DepNoGrad) {
818   struct F1 : public Function<F1> {
819     static variable_list forward(AutogradContext* ctx, Variable input) {
820       auto out = torch::randn(input.sizes());
821       ctx->mark_non_differentiable({out});
822       return {input, out};
823     }
824 
825     static variable_list backward(
826         AutogradContext* ctx,
827         variable_list grad_output) {
828       return {grad_output[0]};
829     }
830   };
831 
832   struct F2 : public Function<F2> {
833     static Variable forward(AutogradContext*, Variable input, Variable ignore) {
834       return input;
835     }
836 
837     static variable_list backward(AutogradContext*, variable_list grad_output) {
838       return {grad_output[0], Variable()};
839     }
840   };
841 
842   auto x = torch::randn(5, torch::requires_grad());
843   auto out = F1::apply(x);
844   Variable &a = out[0], &b = out[1];
845   b = b + 1; // Separate F1 and F2 by another operation
846   ASSERT_TRUE(a.requires_grad());
847   ASSERT_FALSE(b.requires_grad());
848 
849   auto c = F2::apply(a, b);
850   c.backward(torch::ones(c.sizes()), false, false);
851   ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes()));
852 }
853 
854 TEST(CustomAutogradTest, Reentrant) {
855   static Variable y_data = torch::randn({2, 2});
856   struct Reenter : public Function<Reenter> {
857     static Variable forward(AutogradContext* ctx, Variable input) {
858       Variable output;
859       {
860         at::AutoGradMode enable_grad(true);
861         auto x = make_variable(input.tensor_data(), true);
862         auto y = make_variable(y_data.tensor_data(), true);
863         output = x * y;
864 
865         ctx->saved_data["x"] = x;
866         ctx->saved_data["y"] = y;
867         ctx->saved_data["output_var"] = output;
868       }
869       return output.detach();
870     }
871 
872     static variable_list backward(
873         AutogradContext* ctx,
874         variable_list grad_output) {
875       {
876         at::AutoGradMode enable_grad(true);
877         auto out = ctx->saved_data["output_var"].toTensor();
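        // Calling backward() here while another backward pass is running re-enters the autograd engine.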
878         out.sum().backward();
879       }
880       return {ctx->saved_data["x"].toTensor().grad() * grad_output[0]};
881     }
882   };
883 
884   auto x = torch::randn({2, 2}, torch::requires_grad());
885   auto out = Reenter::apply(x);
886   out.sum().backward();
887   ASSERT_VARIABLE_EQ(x.grad(), y_data);
888 }
889 
890 // NOTE: If this fails for apparently unrelated reasons in TSAN be aware of
891 // the TSAN limit on mutex: https://github.com/google/sanitizers/issues/950
892 TEST(CustomAutogradTest, DeepReentrant) {
893   struct DeepReenter : public Function<DeepReenter> {
894     static Variable forward(AutogradContext* ctx, Variable x) {
895       {
896         at::AutoGradMode enable_grad(true);
897         ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
898       }
899       return ctx->saved_data["x"].toTensor().detach();
900     }
901 
902     static variable_list backward(
903         AutogradContext* ctx,
904         variable_list grad_output) {
905       if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
906         return grad_output;
907       }
908       {
909         at::AutoGradMode enable_grad(true);
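        // Recursively re-enter the engine until the saved value reaches zero.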
910         apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
911         return grad_output;
912       }
913     }
914   };
915 
916   // This should not stack overflow
917   auto v =
918       torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true));
919   DeepReenter::apply(v).sum().backward();
920 }
921 
922 TEST(CustomAutogradTest, ReentrantPriority) {
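  // Records the order in which backward nodes run: 0 = MyFunction::backward, 1 = Reenter::backward.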
923   static std::vector<int> order;
924 
925   struct MyFunction : public Function<MyFunction> {
926     static Variable forward(AutogradContext*, Variable x) {
927       return x;
928     }
929 
930     static variable_list backward(AutogradContext*, variable_list grad) {
931       order.push_back(0);
932       return grad;
933     }
934   };
935 
936   struct Reenter : public Function<Reenter> {
937     static Variable forward(AutogradContext* ctx, Variable x) {
938       {
939         at::AutoGradMode enable_grad(true);
940         ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
941       }
942       return ctx->saved_data["x"].toTensor().detach();
943     }
944 
945     static variable_list backward(
946         AutogradContext* ctx,
947         variable_list grad_output) {
948       order.push_back(1);
949       if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
950         return grad_output;
951       }
952       {
953         at::AutoGradMode enable_grad(true);
954         apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
955         return grad_output;
956       }
957     }
958   };
959 
960   auto a = MyFunction::apply(
961       torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true)));
962   auto b = Reenter::apply(
963       torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true)));
964   auto v = a * b;
965   v.backward();
966 
967   // All the reentrant tasks should be prioritized over the MyFunction backward
968   // task.
969   ASSERT_EQ(order.size(), 10);
970   ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9);
971   ASSERT_EQ(order.back(), 0);
972   // Clear the static variable in case the test gets executed in a loop
973   order.clear();
974 }
975 
976 TEST(CustomAutogradTest, Hooks) {
977   Variable x = torch::ones({5, 5}, torch::requires_grad());
978   Variable y = torch::ones({5, 5}) * 4;
979   y.set_requires_grad(true);
980 
981   int counter = 0;
982 
983   std::function<void(int, Variable)> bw_hook(
984       [&counter](int inc, Variable grad) { counter += inc; });
985 
986   Variable z = x * x + x * 2 + x * y + y;
987   x.register_hook([&bw_hook](Variable grad) { bw_hook(0, grad); });
988   auto hook_1 =
989       z.register_hook([&bw_hook](Variable grad) { bw_hook(1, grad); });
990   z.backward(torch::ones({5, 5}), true, true);
991   ASSERT_EQ(counter, 1);
992 
993   auto hook_2 =
994       z.register_hook([&bw_hook](Variable grad) { bw_hook(2, grad); });
995   z.backward(torch::ones({5, 5}), true, true);
996   ASSERT_EQ(counter, 4);
997 
998   z.remove_hook(hook_2);
999   z.backward(torch::ones({5, 5}), true, true);
1000   ASSERT_EQ(counter, 5);
1001 
1002   std::function<Variable(Variable)> bw_hook_modify(
1003       [](Variable grad) { return grad.mul(2); });
1004 
1005   z.remove_hook(hook_1);
1006   z.register_hook(bw_hook_modify);
1007   y.grad().zero_();
1008   z.backward(torch::ones({5, 5}), true, false);
1009   ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2);
1010 
1011   y.register_hook(bw_hook_modify);
1012   y.grad().zero_();
1013   z.backward(torch::ones({5, 5}), false, false);
1014   ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4);
1015 
1016   ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
1017 }
1018 
1019 TEST(CustomAutogradTest, HooksInplace) {
1020   auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
1021 
1022   int hook1_count = 0;
1023   auto hook1 = ([&hook1_count](Variable grad) {
1024     hook1_count++;
1025     ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
1026   });
1027 
1028   int hook2_count = 0;
1029   auto hook2 = ([&hook2_count](Variable grad) {
1030     hook2_count++;
1031     ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
1032   });
1033 
1034   a.register_hook(hook1);
1035   a.mul_(2);
1036   a.register_hook(hook2);
1037 
1038   auto out = (a + 1).sum();
1039   out.backward();
1040 
1041   ASSERT_EQ(hook1_count, 1);
1042   ASSERT_EQ(hook2_count, 1);
1043 }
1044 
1045 TEST(CustomAutogradTest, HooksInplaceWithRetainsGrad) {
1046   auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
1047 
1048   int hook1_count = 0;
1049   auto hook1 = ([&hook1_count](Variable grad) {
1050     hook1_count++;
1051     ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
1052   });
1053 
1054   int hook2_count = 0;
1055   auto hook2 = ([&hook2_count](Variable grad) {
1056     hook2_count++;
1057     ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
1058   });
1059 
1060   int hook3_count = 0;
1061   auto hook3 = ([&hook3_count](Variable grad) {
1062     hook3_count++;
1063     ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
1064   });
1065 
1066   a.register_hook(hook1);
1067   a.retain_grad();
1068   a.register_hook(hook2);
1069 
1070   a.mul_(2);
1071   a.register_hook(hook3);
1072 
1073   auto out = (a + 1).sum();
1074   out.backward();
1075 
1076   ASSERT_EQ(hook1_count, 1);
1077   ASSERT_EQ(hook2_count, 1);
1078   ASSERT_EQ(hook3_count, 1);
1079 
1080   ASSERT_TRUE(a.retains_grad());
1081   ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
1082 }
1083 
1084 TEST(CustomAutogradTest, HooksInplaceTwiceWithRetainsGrad) {
1085   auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
1086 
1087   int hook1_count = 0;
1088   auto hook1 = ([&hook1_count](Variable grad) {
1089     hook1_count++;
1090     ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
1091   });
1092 
1093   int hook2_count = 0;
1094   auto hook2 = ([&hook2_count](Variable grad) {
1095     hook2_count++;
1096     ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
1097   });
1098 
1099   int hook3_count = 0;
1100   auto hook3 = ([&hook3_count](Variable grad) {
1101     hook3_count++;
1102     ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
1103   });
1104 
1105   a.register_hook(hook1);
1106   a.retain_grad();
1107   a.register_hook(hook2);
1108 
1109   a.mul_(2);
1110   a.mul_(2);
1111   a.register_hook(hook3);
1112 
1113   auto out = (a + 1).sum();
1114   out.backward();
1115 
1116   ASSERT_EQ(hook1_count, 1);
1117   ASSERT_EQ(hook2_count, 1);
1118   ASSERT_EQ(hook3_count, 1);
1119 
1120   ASSERT_TRUE(a.retains_grad());
1121   ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
1122 }
1123 
1124 TEST(CustomAutogradTest, HookNone) {
1125   struct NoneGradientFunction : public Function<NoneGradientFunction> {
1126     static variable_list forward(AutogradContext* ctx, Variable x, Variable y) {
1127       return {x, y};
1128     }
1129 
1130     static variable_list backward(AutogradContext* ctx, variable_list grad) {
1131       return {grad[0], Variable()};
1132     }
1133   };
1134 
1135   bool was_called = false;
1136 
1137   auto hook = ([&was_called](Variable grad) {
1138     ASSERT_TRUE(grad.defined());
1139     was_called = true;
1140   });
1141 
1142   auto x = torch::randn({5, 5}, torch::requires_grad());
1143   auto y = torch::randn({5, 5});
1144 
1145   auto out = NoneGradientFunction::apply(x, y);
1146   Variable rx = x[0], ry = x[1];
1147 
1148   rx.register_hook(hook);
1149   ry.register_hook(hook);
1150   (rx + ry).sum().backward();
1151   ASSERT_TRUE(was_called);
1152 }
1153 
1154 TEST(CustomAutogradTest, BackwardWithInputs) {
1155   Variable x = torch::randn({5, 5}, torch::requires_grad());
1156   Variable y = torch::randn({5, 5}, torch::requires_grad());
1157   Variable z = x * x + x * y + y * y;
1158   Variable x_grad_expected = 2 * x + y;
1159   Variable y_grad_expected = x + 2 * y;
1160 
1161   z.backward(torch::ones({5, 5}), false, false, {x});
1162 
1163   ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
1164   ASSERT_FALSE(y.grad().defined());
1165 }
1166 
1167 TEST(CustomAutogradTest, BackwardWithEmptyInputs) {
1168   Variable x = torch::randn({5, 5}, torch::requires_grad());
1169   Variable y = torch::randn({5, 5}, torch::requires_grad());
1170   Variable z = x * x + x * y + y * y;
1171   Variable x_grad_expected = 2 * x + y;
1172   Variable y_grad_expected = x + 2 * y;
1173   ASSERT_THROWS_WITH(
1174       z.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{}),
1175       "cannot be empty");
1176 }
1177 
1178 TEST(CustomAutogradTest, BackwardWithNonLeafInputs) {
1179   Variable x = torch::randn({5, 5}, torch::requires_grad());
1180   Variable y = torch::randn({5, 5}, torch::requires_grad());
1181   Variable z = x * x;
1182   Variable w = y * z + x * y + y * y;
1183 
1184   Variable x_grad_expected = 2 * x * y + y;
1185   Variable z_grad_expected = y;
1186 
1187   w.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{x, z});
1188 
1189   ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
1190   ASSERT_VARIABLE_EQ(z.grad(), z_grad_expected);
1191   ASSERT_FALSE(y.grad().defined());
1192 }
1193 
1194 TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) {
1195   c10::WarningUtils::WarnAlways guard(true);
1196 
1197   torch::Tensor x = torch::randn({5, 5}).set_requires_grad(true);
1198   auto z = x * x;
1199   {
1200     WarningCapture warnings;
1201     z.backward(torch::ones({5, 5}), std::nullopt, true);
1202     ASSERT_TRUE(
1203         warnings.str().find("Using backward() with create_graph=True") !=
1204         std::string::npos);
1205   }
1206 
1207   {
1208     WarningCapture warnings;
1209     torch::autograd::backward({z}, {torch::ones({5, 5})}, std::nullopt, true);
1210     ASSERT_TRUE(
1211         warnings.str().find("Using backward() with create_graph=True") !=
1212         std::string::npos);
1213   }
1214 }
1215 
1216 /**
1217  * Tests for AutogradNotImplementedFallback
 1218  * - Check that a NotImplemented node is attached when inputs require grad,
 1219  *   and that no such node is created when no inputs require grad
1220  * - check_inplace logic
1221  * - view ops
1222  * - TODO: Tests for debug-only checks? Don't need for now because CI doesn't
1223  * test non-NDEBUG builds.
1224  * - tensorlist input and output
1225  * - multiple outputs / non-tensor output
1226  * - rebase_history vs set_history
1227  */
1228 namespace {
1229 
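// The ops below are registered via REGISTER_TEST_OP without derivative formulas, so autograd
// dispatches them to the NotImplemented fallback kernels under test.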
1230 torch::Tensor inplace_op(
1231     const torch::Tensor& self,
1232     const torch::Tensor& other) {
1233   return self.add_(other);
1234 }
1235 
1236 std::tuple<torch::Tensor, torch::Tensor> two_arg_inplace_op(
1237     const torch::Tensor& self,
1238     const torch::Tensor& other) {
1239   other.add_(self);
1240   self.add_(other);
1241   return std::tuple<torch::Tensor, torch::Tensor>(self, other);
1242 }
1243 
1244 std::tuple<torch::Tensor, torch::Tensor> two_pairs_of_view_op(
1245     const torch::Tensor& self,
1246     const torch::Tensor& other) {
1247   // This is not allowed. We test below that calling into the boxed kernel
1248   // raises an error.
1249   return std::tuple<torch::Tensor, torch::Tensor>(self, other);
1250 }
1251 
1252 std::tuple<torch::Tensor, torch::Tensor> non_first_view_op(
1253     const torch::Tensor& self,
1254     const torch::Tensor& other) {
1255   // This is not allowed. We test below that calling into the boxed kernel
1256   // raises an error.
1257   return std::tuple<torch::Tensor, torch::Tensor>(self.clone(), other);
1258 }
1259 
1260 int64_t ret_single_non_tensor(
1261     const torch::Tensor& self,
1262     const torch::Tensor& other) {
1263   return 12;
1264 }
1265 
1266 torch::Tensor opt_op(
1267     const torch::Tensor& self,
1268     const std::optional<at::Tensor>& other) {
1269   if (other.has_value()) {
1270     return self + other.value();
1271   } else {
1272     return self.clone();
1273   }
1274 }
1275 
1276 torch::Tensor my_custom_op(
1277     const torch::Tensor& self,
1278     const torch::Tensor& other) {
1279   return self + other;
1280 }
1281 
1282 std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor(
1283     const torch::Tensor& self,
1284     const torch::Tensor& other) {
1285   auto a = self - other;
1286   auto b = self + other;
1287   return std::tuple<torch::Tensor, torch::Tensor, int64_t>(a, b, 12);
1288 }
1289 
1290 torch::Tensor view_op(const torch::Tensor& self) {
1291   return self.alias();
1292 }
1293 
1294 torch::Tensor view_op_with_extra_arg(
1295     const torch::Tensor& self,
1296     const torch::Tensor& other) {
1297   return self.alias();
1298 }
1299 
1300 std::vector<torch::Tensor> ret_tensor_vector_view(
1301     const torch::Tensor& self,
1302     const torch::Tensor& other) {
1303   return {self.alias(), self.alias()};
1304 }
1305 
1306 std::vector<at::Tensor> ret_tensor_vector(
1307     const torch::Tensor& self,
1308     const torch::Tensor& other) {
1309   std::vector<at::Tensor> out;
1310   out.push_back(self + other);
1311   out.push_back(self - other);
1312   return out;
1313 }
1314 
1315 torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) {
1316   const auto& res = self.clone();
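  // add_ can be called through a const reference because at::Tensor has pointer semantics;
  // const does not protect the underlying data.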
1317   for (const auto& t : other) {
1318     res.add_(t);
1319   }
1320   return res;
1321 }
1322 
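// Registers `fn` in the _test namespace with a plain CPU kernel, and installs
// autogradNotImplementedFallback() for the Autograd key and
// autogradNotImplementedInplaceOrViewFallback() for the ADInplaceOrView key.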
1323 #define REGISTER_TEST_OP(name, schema, fn)                                 \
1324   auto m = MAKE_TORCH_LIBRARY(_test);                                      \
1325   m.def(schema);                                                           \
1326   auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd);              \
1327   auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);                        \
1328   auto m_inplaceorview = MAKE_TORCH_LIBRARY_IMPL(_test, ADInplaceOrView);  \
1329   m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn));                   \
1330   m_autograd.impl(                                                         \
1331       name, c10::DispatchKey::Autograd, autogradNotImplementedFallback()); \
1332   m_inplaceorview.impl(                                                    \
1333       name,                                                                \
1334       c10::DispatchKey::ADInplaceOrView,                                   \
1335       autogradNotImplementedInplaceOrViewFallback());
1336 
1337 template <typename F>
1338 void assertBasicChecks(F op) {
1339   auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1340   auto b = torch::tensor({1.}, {torch::kFloat32});
1341   auto c = torch::tensor({1.}, {torch::kFloat32});
1342 
1343   // If any input requires grad, backward should throw "is not implemented"
1344   auto out1 = op(a, b);
1345   ASSERT_THROWS_WITH(out1.backward(), "is not implemented");
1346 
1347   // Should not have a grad_fn if none of the inputs require grad
1348   auto out2 = op(b, c);
1349   ASSERT_THROWS_WITH(
1350       out2.backward(),
1351       "element 0 of tensors does not require grad and does not have a grad_fn");
1352 
1353   // TODO: Forward AD Tests?
1354 }
1355 
1356 } // namespace
1357 
1358 TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) {
1359   REGISTER_TEST_OP(
1360       "ret_single_non_tensor",
1361       "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int",
1362       ret_single_non_tensor);
1363   auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1364       "_test::ret_single_non_tensor", "");
1365   auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1366     return callOpUnboxed<int64_t, const torch::Tensor&, const torch::Tensor&>(
1367         opHandle, _1, _2);
1368   };
1369 
1370   auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1371   auto b = torch::tensor({1.}, {torch::kFloat32});
1372 
1373   ASSERT_EQ(op(a, b), ret_single_non_tensor(a, b));
1374 }
1375 
1376 TEST(TestAutogradNotImplementedFallback, InplaceOp) {
1377   REGISTER_TEST_OP(
1378       "inplace_op",
1379       "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)",
1380       inplace_op);
1381   auto opHandle =
1382       c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", "");
1383   auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1384     return callOpUnboxed<
1385         torch::Tensor,
1386         const torch::Tensor&,
1387         const torch::Tensor&>(opHandle, _1, _2);
1388   };
1389 
1390   auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1391   auto b = torch::tensor({1.}, {torch::kFloat32});
1392 
1393   // Check in-place
1394   ASSERT_THROWS_WITH(
1395       op(a, b),
1396       "a leaf Variable that requires grad is being used in an in-place operation");
1397   op(b, a);
1398   a = a.clone();
1399   b = b.clone();
1400   auto c = op(a, b);
1401   ASSERT_TRUE(torch::allclose(c, inplace_op(a, b)));
1402 
1403   // Test in-place on view
1404   auto base =
1405       torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
1406   auto view = base.view(-1);
1407   auto t = torch::tensor({1.}, {torch::kFloat32});
1408 
1409   torch::Tensor v_nograd;
1410   {
1411     c10::NoGradGuard guard;
1412     v_nograd = base.view(-1);
1413     op(v_nograd, t);
1414   }
1415 
1416   ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode");
1417   ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl());
1418   ASSERT_THAT(
1419       op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward"));
1420 }
1421 
1422 TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) {
1423   REGISTER_TEST_OP(
1424       "two_arg_inplace_op",
1425       "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))",
1426       two_arg_inplace_op);
1427   auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1428       "_test::two_arg_inplace_op", "");
1429   auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1430     return callOpUnboxed<
1431         std::tuple<torch::Tensor, torch::Tensor>,
1432         const torch::Tensor&,
1433         const torch::Tensor&>(opHandle, _1, _2);
1434   };
1435   auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1436   auto b = torch::tensor({1.}, {torch::kFloat32});
1437 
1438   // Both are modified in-place!
1439   ASSERT_THROWS_WITH(
1440       op(a, b),
1441       "a leaf Variable that requires grad is being used in an in-place operation");
1442   ASSERT_THROWS_WITH(
1443       op(b, a),
1444       "a leaf Variable that requires grad is being used in an in-place operation");
1445 
1446   auto c =
1447       torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
1448   auto d =
1449       torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
1450 
1451   auto saved_version_c = c._version();
1452   auto saved_version_d = d._version();
1453   op(c, d);
1454   ASSERT_NE(c._version(), saved_version_c);
1455   ASSERT_NE(d._version(), saved_version_d);
1456 }
1457 
1458 TEST(TestAutogradNotImplementedFallback, OptOp) {
1459   REGISTER_TEST_OP(
1460       "opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op);
1461   auto opHandle =
1462       c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", "");
1463   auto op = [&](const torch::Tensor& _1,
1464                 const std::optional<torch::Tensor>& _2) {
1465     return callOpUnboxed<
1466         torch::Tensor,
1467         const torch::Tensor&,
1468         const std::optional<torch::Tensor>&>(opHandle, _1, _2);
1469   };
1470 
1471   auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1472   auto b = torch::tensor({1.}, {torch::kFloat32});
1473 
1474   ASSERT_TRUE(torch::allclose(op(a, b), opt_op(a, b)));
1475   ASSERT_TRUE(torch::allclose(op(a, {}), opt_op(a, {})));
1476 }
1477 
1478 TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) {
1479   REGISTER_TEST_OP(
1480       "my_custom_op",
1481       "_test::my_custom_op(Tensor self, Tensor other) -> Tensor",
1482       my_custom_op);
1483   auto opHandle =
1484       c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", "");
1485   auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1486     return callOpUnboxed<
1487         torch::Tensor,
1488         const torch::Tensor&,
1489         const torch::Tensor&>(opHandle, _1, _2);
1490   };
1491 
1492   assertBasicChecks(op);
1493 }
1494 
1495 TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) {
1496   REGISTER_TEST_OP(
1497       "ret_tuple_non_tensor",
1498       "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)",
1499       ret_tuple_non_tensor);
1500   auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1501       "_test::ret_tuple_non_tensor", "");
1502   auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1503     auto out = callOpUnboxed<
1504         std::tuple<torch::Tensor, torch::Tensor, int64_t>,
1505         const torch::Tensor&,
1506         const torch::Tensor&>(opHandle, _1, _2);
1507     auto [out0, out1, out2] = std::move(out);
1508     return out0;
1509   };
1510 
1511   assertBasicChecks(op);
1512 }
1513 
1514 TEST(TestAutogradNotImplementedFallback, ViewOp) {
1515   REGISTER_TEST_OP(
1516       "view_op", "_test::view_op(Tensor(a) self) -> Tensor(a)", view_op);
1517   auto opHandle =
1518       c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", "");
1519   auto op = [&](const torch::Tensor& _1) {
1520     return callOpUnboxed<torch::Tensor, const torch::Tensor&>(opHandle, _1);
1521   };
1522   auto b = torch::tensor({1.}, {torch::kFloat32});
1523   auto v = op(b);
1524   ASSERT_TRUE(v.is_view());
1525   ASSERT_EQ(v._base().unsafeGetTensorImpl(), b.unsafeGetTensorImpl());
1526 
1527   auto b1 =
1528       torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
1529   auto v1 = op(b1);
1530   ASSERT_TRUE(v1.is_view());
1531   ASSERT_EQ(v1._base().unsafeGetTensorImpl(), b1.unsafeGetTensorImpl());
1532 
1533   // Test inplace on view
1534   auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1535 
1536   // Raises from rebase_history when it refreshes the grad_fn
1537   ASSERT_THROWS_WITH(
1538       v1.add_(t), "which does not have a derivative implemented is forbidden");
1539   // base should not be aware of the views, so this is still okay
1540   b1.add_(t);
1541   ASSERT_THROWS_WITH(
1542       v1.grad_fn(),
1543       "which does not have a derivative implemented is forbidden");
1544 }
1545 
1546 TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) {
1547   REGISTER_TEST_OP(
1548       "view_op_with_extra_arg",
1549       "_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)",
1550       view_op_with_extra_arg);
1551   auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1552       "_test::view_op_with_extra_arg", "");
1553   auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1554     return callOpUnboxed<
1555         torch::Tensor,
1556         const torch::Tensor&,
1557         const torch::Tensor&>(opHandle, _1, _2);
1558   };
1559   assertBasicChecks(op);
1560   auto a = torch::tensor({1.}, {torch::kFloat32});
1561   auto b = torch::tensor({2.}, {torch::kFloat32});
1562   auto out1 = op(a, b);
1563   ASSERT_TRUE(out1.is_view());
1564   ASSERT_EQ(out1._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
1565 }
1566 
1567 TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) {
1568   REGISTER_TEST_OP(
1569       "ret_tensor_vector_view",
1570       "_test::ret_tensor_vector_view(Tensor(a) self, Tensor other) -> Tensor[](a)",
1571       ret_tensor_vector_view);
1572   auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1573       "_test::ret_tensor_vector_view", "");
1574   auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1575     return callOpUnboxed<
1576         std::vector<at::Tensor>,
1577         const torch::Tensor&,
1578         const torch::Tensor&>(opHandle, _1, _2);
1579   };
1580   auto a = torch::tensor({1.}, {torch::kFloat32});
1581   auto b = torch::tensor({1.}, {torch::kFloat32});
1582   auto out = op(a, b);
1583   ASSERT_TRUE(out[0].is_view());
1584   ASSERT_EQ(out[0]._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
1585   ASSERT_TRUE(out[1].is_view());
1586   ASSERT_EQ(out[1]._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
1587 }
1588 
1589 TEST(TestAutogradNotImplementedFallback, DoubleViewOP) {
1590   REGISTER_TEST_OP(
1591       "two_pairs_of_view_op",
1592       "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))",
1593       two_pairs_of_view_op);
1594   auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1595       "_test::two_pairs_of_view_op", "");
1596   auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1597     return callOpUnboxed<
1598         std::tuple<torch::Tensor, torch::Tensor>,
1599         const torch::Tensor&,
1600         const torch::Tensor&>(opHandle, _1, _2);
1601   };
1602   auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1603   auto b = torch::tensor({1.}, {torch::kFloat32});
1604   ASSERT_THROWS_WITH(
1605       op(a, b),
1606       "Expected only a single output in the operator schema to have a non-write alias annotation");
1607 }
1608 
1609 TEST(TestAutogradNotImplementedFallback, NonFirstViewOP) {
1610   REGISTER_TEST_OP(
1611       "non_first_view_op",
1612       "_test::non_first_view_op(Tensor self, Tensor(b) other) -> (Tensor, Tensor(b))",
1613       non_first_view_op);
1614   auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1615       "_test::non_first_view_op", "");
1616   auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1617     return callOpUnboxed<
1618         std::tuple<torch::Tensor, torch::Tensor>,
1619         const torch::Tensor&,
1620         const torch::Tensor&>(opHandle, _1, _2);
1621   };
1622   auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1623   auto b = torch::tensor({1.}, {torch::kFloat32});
1624   ASSERT_THROWS_WITH(
1625       op(a, b), "can only create view relationships between the first");
1626 }
1627 
1628 TEST(TestAutogradNotImplementedFallback, RetTensorVector) {
1629   REGISTER_TEST_OP(
1630       "ret_tensor_vector",
1631       "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]",
1632       ret_tensor_vector);
1633   auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1634       "_test::ret_tensor_vector", "");
1635   auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1636     return callOpUnboxed<
1637         std::vector<at::Tensor>,
1638         const torch::Tensor&,
1639         const torch::Tensor&>(opHandle, _1, _2)[0];
1640   };
1641   assertBasicChecks(op);
1642 }
1643 
1644 TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
1645   REGISTER_TEST_OP(
1646       "tensorlist_op",
1647       "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor",
1648       tensorlist_op);
1649   auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1650       "_test::tensorlist_op", "");
1651   auto op = [&](torch::Tensor _1, at::TensorList _2) {
1652     return callOpUnboxed<torch::Tensor, const torch::Tensor&, at::TensorList>(
1653         opHandle, _1, _2);
1654   };
1655 
1656   auto a = torch::tensor({1.}, {torch::kFloat32});
1657   auto b = torch::tensor({1.}, {torch::kFloat32});
1658   auto c = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1659   std::vector<torch::Tensor> vec = {b, c};
1660   auto out = op(a, vec);
1661 
1662   ASSERT_THROWS_WITH(
1663       torch::autograd::grad({out}, {vec[0]}),
1664       "element 0 of the input tensors does not require grad");
1665   ASSERT_THROWS_WITH(
1666       torch::autograd::grad({out}, {vec[1]}), "is not implemented");
1667 
1668   ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec)));
1669 }
1670 
1671 // TODO add these tests if needed
1672 // test_once_differentiable
1673 // test_sparse_backward
1674 // test_save_output_nr
1675 // test_free_deep_graph_pyfunction
1676 // test_naughty_anomaly_access
1677 // test_naughty_autograd-function_stashing_ctx
1678 // test_custom_autograd_repeated_grad_grad
1679 // test_return_leaf
1680 // test_anomaly_detect_nan
1681 // test_no_grad_copy
1682