1 #include <ATen/core/boxing/impl/test_helpers.h>
2 #include <gtest/gtest.h>
3
4 #include <ATen/core/op_registration/op_registration.h>
5 #include <torch/torch.h>
6
7 #include <torch/csrc/autograd/FunctionsManual.h>
8 #include <torch/csrc/autograd/functions/basic_ops.h>
9
10 #include <test/cpp/api/support.h>
11
12 using namespace torch::autograd;
13 using namespace torch::test;
14
15 #define ASSERT_VARIABLE_EQ(a, b) ASSERT_TRUE(torch::allclose((a), (b)))
16 #define EXPECT_VARIABLE_EQ(a, b) EXPECT_TRUE(torch::allclose((a), (b)))
17
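// Renders the autograd graph reachable from `node` as nested node names.
// For example, for z = x + y with two leaf inputs this yields something like
// "AddBackward0(AccumulateGrad()AccumulateGrad())".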
std::string graph_desc(std::shared_ptr<Node> node) {
19 if (!node) {
20 return "None";
21 }
22 auto result = node->name() + "(";
23 auto next_edges = node->next_edges();
24 for (auto& edge : next_edges) {
25 result += graph_desc(edge.function);
26 }
27 return result + ")";
28 }
29
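// simple_fn(x, y) = x + 2y + xy, so d/dx = 1 + y and d/dy = 2 + x; the tests
// below check these analytic gradients.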
Variable simple_fn(const Variable& x, const Variable& y) {
31 return x + 2 * y + x * y;
32 }
33
TEST(AutogradAPITests, RegisterHookVoidReturnAcceptsUndefinedTensor) {
35 auto x = at::zeros({}, at::kCPU);
36 x.requires_grad_();
37 x.register_hook([](at::TensorBase x) { return; });
38 auto y = torch::autograd::UndefinedGrad().apply({x});
39 y[0].backward();
40 }
41
TEST(AutogradAPITests, RegisterHookTensorReturnAcceptsUndefinedTensor) {
43 auto x = at::zeros({}, at::kCPU);
44 x.requires_grad_();
45 x.register_hook([](at::Tensor x) -> at::Tensor { return x; });
46 auto y = torch::autograd::UndefinedGrad().apply({x});
47 y[0].backward();
48 }
49
TEST(AutogradAPITests, BackwardSimpleTest) {
51 Variable x = torch::randn({2, 2}, torch::requires_grad());
52 Variable y = torch::randn({2, 2}, torch::requires_grad());
53 auto res = simple_fn(x, y);
54 backward({res.sum()}, {});
55
56 ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({2, 2}));
57 ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2}) * 2);
58 }
59
TEST(AutogradAPITests, BackwardTest) {
61 Variable x = torch::randn({2, 2}, torch::requires_grad());
62 Variable y = torch::randn({2, 2}, torch::requires_grad());
63 auto res = simple_fn(x, y);
64 backward({res}, {torch::ones({2, 2})}, {}, true);
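  // With retain_graph unspecified it defaults to create_graph (true here), so
  // the graph survives and backward can be called a second time; gradients
  // accumulate, which is why the expected values below are doubled.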
65
66 backward({res}, {torch::ones({2, 2})});
67
68 ASSERT_VARIABLE_EQ(x.grad(), 2 * (y + torch::ones({2, 2})));
69 ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2}) * 2));
70 }
71
TEST(AutogradAPITests, GradSimpleTest) {
73 // basic grad
74 Variable x = torch::randn({2, 2}, torch::requires_grad());
75 Variable y = torch::randn({2, 2}, torch::requires_grad());
76 auto res = simple_fn(x, y);
77 auto grad_res = grad({res}, {x, y}, {torch::ones({2, 2})});
78
79 ASSERT_VARIABLE_EQ(grad_res[0], y + torch::ones({2, 2}));
80 ASSERT_VARIABLE_EQ(grad_res[1], x + torch::ones({2, 2}) * 2);
81 }
82
TEST(AutogradAPITests, GradTest) {
84 Variable x = torch::randn({2, 2}, torch::requires_grad());
85 Variable y = torch::randn({2, 2}, torch::requires_grad());
86 auto res = simple_fn(x, y);
87 res.backward(torch::ones({2, 2}), false, true);
88
89 Variable x_grad = y + torch::ones({2, 2});
90 Variable y_grad = x + torch::ones({2, 2}) * 2;
91 ASSERT_VARIABLE_EQ(x.grad(), x_grad);
92 ASSERT_VARIABLE_EQ(y.grad(), y_grad);
93
94 Variable grad_sum = 2 * x.grad() + y.grad();
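  // x.grad() = y + 1 and y.grad() = x + 2 were created with create_graph=true,
  // so grad_sum = 2 * (y + 1) + (x + 2) is itself differentiable and
  // d(grad_sum)/dx = 1, which is why x_hv below is all ones.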
95 auto x_hv = grad({grad_sum}, {x}, {torch::ones({2, 2})}, {}, true);
96
97 ASSERT_VARIABLE_EQ(x_hv[0], torch::ones({2, 2}));
98 ASSERT_VARIABLE_EQ(x.grad(), x_grad);
99 ASSERT_VARIABLE_EQ(y.grad(), y_grad);
100 }
101
TEST(AutogradAPITests, GradNonLeafTest) {
103 Variable x_init = torch::randn({2, 2}, torch::requires_grad());
104 Variable x = x_init;
105 Variable y = torch::randn({2, 2}, torch::requires_grad());
106 Variable grad_output = torch::ones({2, 2});
107
108 for (int i = 0; i < 5; ++i) {
109 auto res = simple_fn(x, y);
110 auto input_grads = grad({res}, {x}, {grad_output}, {}, true);
111
112 Variable grad_x_expected = y + torch::ones({2, 2});
113 ASSERT_VARIABLE_EQ(input_grads[0], grad_x_expected);
114 ASSERT_FALSE(x.grad().defined());
115 ASSERT_FALSE(y.grad().defined());
116 x = x + 0.05 * input_grads[0];
117 }
118
119 float val_init = simple_fn(x_init, y).sum().item().toFloat();
120 float val_final = simple_fn(x, y).sum().item().toFloat();
121 ASSERT_TRUE(val_final > val_init);
122
123 x.backward(grad_output, false, true);
124 ASSERT_TRUE(x_init.grad().defined());
125 ASSERT_TRUE(y.grad().defined());
126 }
127
TEST(AutogradAPITests, GradUnreachableTest) {
129 Variable x = torch::ones({1}, torch::requires_grad());
130 Variable y = torch::ones({1}, torch::requires_grad());
131
132 Variable z = x * 2;
133 Variable w = y * 2;
134
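  // The positional arguments are (outputs, inputs, grad_outputs, retain_graph,
  // create_graph, allow_unused). y is not reachable from x * 2, so with
  // allow_unused=true its gradient comes back as an undefined tensor.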
135 auto grad_res = grad({x * 2}, {x, y}, {}, {}, false, true);
136 ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
137 ASSERT_FALSE(grad_res[1].defined());
138
139 // This is slightly different than the case above, because z doesn't even
140 // have a grad accumulator allocated.
141 z = torch::ones({1}, torch::requires_grad());
142 grad_res = grad({x * 2}, {x, z}, {}, {}, false, true);
143
144 ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
145 ASSERT_FALSE(grad_res[1].defined());
146
  // With allow_unused=false, an unreachable input would produce an undefined
  // grad, so this should throw.
148 ASSERT_THROWS_WITH(
149 grad({x * 2}, {x, y}, {}, {}, false, false), "Set allow_unused=True");
150 }
151
TEST(CustomAutogradTest, GradUnreachableDiscoveryTest) {
153 // Test that certain nodes are not erroneously executed when an input
154 // is unreachable. See #39784
155 struct MyFunction : public Function<MyFunction> {
156 static Variable forward(AutogradContext* ctx, Variable var) {
157 return var;
158 }
159
160 static variable_list backward(
161 AutogradContext* ctx,
162 variable_list grad_output) {
163 ADD_FAILURE() << "This node should not be executed!";
164 return grad_output;
165 }
166 };
167
168 auto x = torch::randn(1, torch::requires_grad());
169 auto x1 = torch::randn(1);
170 auto x2 = MyFunction::apply(x + x1);
171
172 auto y = torch::randn(1, torch::requires_grad());
173 auto grad_res = torch::autograd::grad({x2}, {y}, {}, {}, false, true);
174 ASSERT_FALSE(grad_res[0].defined());
175 }
176
TEST(AutogradAPITests, EmptyInput) {
178 Variable x = torch::ones({1}, torch::requires_grad());
179 ASSERT_THROWS_WITH(
180 grad({x * 2}, /*inputs=*/{}, {x}), "grad requires non-empty inputs.");
181 }
182
TEST(AutogradAPITests, RetainGrad) {
184 auto input = torch::rand({1, 3}, torch::requires_grad());
185 auto h1 = input * 3;
186 auto out = (h1 * h1).sum();
187
188 {
189 // Warning when grad is accessed for non-leaf tensor
190 WarningCapture warnings;
191 ASSERT_FALSE(h1.grad().defined());
192 ASSERT_TRUE(warnings.str().find("is not a leaf") != std::string::npos);
193 }
194 // It should be possible to call retain_grad() multiple times
195 h1.retain_grad();
196 h1.retain_grad();
197 {
198 // If retain_grad is true for a non-leaf tensor,
199 // there should not be any warning when grad is accessed
200 WarningCapture warnings;
201 ASSERT_FALSE(h1.grad().defined());
202 ASSERT_FALSE(warnings.str().find("is not a leaf") != std::string::npos);
203 }
204
205 // Gradient should be accumulated
206 // NOLINTNEXTLINE(bugprone-argument-comment)
207 out.backward({}, /*keep_graph=*/true);
208 ASSERT_VARIABLE_EQ(h1 * 2, h1.grad());
209 // NOLINTNEXTLINE(bugprone-argument-comment)
210 out.backward({}, /*keep_graph=*/true);
211 ASSERT_VARIABLE_EQ(h1 * 4, h1.grad());
212
213 {
214 torch::NoGradGuard no_grad;
215 input.grad().zero_();
216 }
217 // It should be a no-op for leaves
218 input.retain_grad();
219 input.retain_grad();
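  // out = (3 * input)^2 summed = 9 * input^2, so this backward writes
  // 18 * input into the freshly zeroed input.grad().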
220 out.backward();
221 ASSERT_VARIABLE_EQ(input * 18, input.grad());
222 }
223
TEST(AutogradAPITests, AnomalyMode) {
  // Anomaly mode should warn with the traceback of the offending forward call
  // and then raise an error during backward.
226 torch::autograd::DetectAnomalyGuard detect_anomaly;
227 {
228 WarningCapture warnings;
229 auto x = torch::tensor({5.0}, torch::requires_grad());
230 auto y = x * x;
231 auto z = y * y;
232 y += 1;
233 ASSERT_THROWS_WITH(z.backward(), "inplace");
234 ASSERT_TRUE(
235 warnings.str().find("Traceback of forward") != std::string::npos);
236 }
237 auto double_backward_produce_nan = [](bool should_throw) {
238 auto x = torch::tensor({0.0}, torch::requires_grad());
239 auto y = x.pow(1.5);
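    // At x = 0, d/dx x^1.5 = 1.5 * sqrt(x) is 0, but its own derivative
    // 0.75 / sqrt(x) is inf; multiplied by the zero grad_output below this
    // produces NaN, which the anomaly NaN check reports in the double backward.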
240 auto gr =
241 // NOLINTNEXTLINE(bugprone-argument-comment)
242 grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true);
243 if (should_throw) {
244 WarningCapture warnings;
      ASSERT_THROWS_WITH(
          grad({gr[0]}, {x}, {torch::tensor({0.0})}), "returned nan");
247 auto msgs = warnings.messages();
248 ASSERT_EQ(msgs.size(), 2);
249 ASSERT_TRUE(
250 msgs[0].find("Traceback of forward call that caused the error") !=
251 std::string::npos);
252 ASSERT_TRUE(
253 msgs[1].find(
254 "Traceback of forward call that induced the previous calculation") !=
255 std::string::npos);
256 } else {
257 grad({gr[0]}, {x}, {torch::tensor({0.0})});
258 }
259 };
260
261 double_backward_produce_nan(true);
262 {
263 torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/false);
264 double_backward_produce_nan(false);
265 {
266 torch::autograd::DetectAnomalyGuard detect_anomaly(/*check_nan=*/true);
267 double_backward_produce_nan(true);
268 }
269 }
270 double_backward_produce_nan(true);
271 }
272
TEST(CustomAutogradTest, CustomFunctionReturnInputAsIsAndSavesIt) {
274 struct MyFunction : public Function<MyFunction> {
275 static Variable forward(
276 AutogradContext* ctx,
277 Variable var1,
278 Variable var2) {
279 ctx->save_for_backward({var1, var2});
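      // Note: the comma operator discards var1 * var2; var1 is returned as-is.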
280 return var1 * var2, var1;
281 }
282
283 static variable_list backward(
284 AutogradContext* ctx,
285 variable_list grad_output) {
286 return {};
287 }
288 };
289
290 Variable x = torch::randn({5, 5}, torch::requires_grad());
291 Variable y = torch::randn({5, 5}, torch::requires_grad());
292 MyFunction::apply(x, y);
293 }
294
TEST(CustomAutogradTest, CustomFunction) {
296 struct MyFunction : public Function<MyFunction> {
297 static Variable forward(
298 AutogradContext* ctx,
299 Variable var1,
300 int mul,
301 Variable var2) {
302 ctx->saved_data["mul"] = mul;
303 ctx->save_for_backward({var1, var2});
304 return var1 + mul * var2 + var1 * var2;
305 }
306
307 static variable_list backward(
308 AutogradContext* ctx,
309 variable_list grad_output) {
310 int mul = ctx->saved_data["mul"].toInt();
311 auto saved = ctx->get_saved_variables();
312 auto var1 = saved[0];
313 auto var2 = saved[1];
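      // Return one gradient per forward input (var1, mul, var2), in order; the
      // non-tensor argument `mul` gets an undefined Variable() placeholder.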
314 variable_list output = {
315 grad_output[0] + grad_output[0] * var2,
316 Variable(),
317 grad_output[0] * mul + grad_output[0] * var1};
318 return output;
319 }
320 };
321
322 Variable x = torch::randn({5, 5}, torch::requires_grad());
323 Variable y = torch::randn({5, 5}, torch::requires_grad());
324 auto res = MyFunction::apply(x, 2, y);
325 auto go = torch::ones({}, torch::requires_grad());
326 res.sum().backward(go, false, true);
327
328 ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
329 ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}) * 2);
330 }
331
TEST(CustomAutogradTest, CustomFunctionWithTensorList) {
333 struct MyFunction : public Function<MyFunction> {
334 static Variable forward(AutogradContext* ctx, at::TensorList tensors) {
335 torch::autograd::variable_list vars;
336 for (const at::Tensor& tensor : tensors) {
337 vars.push_back(tensor);
338 }
339 ctx->save_for_backward(vars);
340 return tensors[0] + tensors[1] + tensors[0] * tensors[1];
341 }
342
343 static variable_list backward(
344 AutogradContext* ctx,
345 variable_list grad_output) {
346 auto saved = ctx->get_saved_variables();
347 auto var1 = saved[0];
348 auto var2 = saved[1];
349 variable_list output = {
350 grad_output[0] + grad_output[0] * var2,
351 grad_output[0] + grad_output[0] * var1};
352 return output;
353 }
354 };
355
356 at::Tensor x = torch::randn({5, 5}, torch::requires_grad());
357 at::Tensor y = torch::randn({5, 5}, torch::requires_grad());
358 torch::autograd::variable_list variables = {x, y};
359 at::TensorList tensors = variables;
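  // at::TensorList is a non-owning ArrayRef<Tensor>, so `variables` must stay
  // alive for as long as `tensors` is used.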
360 auto res = MyFunction::apply(tensors);
361 auto go = torch::ones({}, torch::requires_grad());
362 res.sum().backward(go, false, true);
363
364 ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5, 5}));
365 ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5, 5}));
366 }
367
TEST(CustomAutogradTest, GraphTaskTrimEdges) {
369 struct MyFunction : public Function<MyFunction> {
370 static Variable forward(
371 AutogradContext* ctx,
372 Variable var1,
373 Variable var2,
374 int mul,
375 bool needs_input1_grad,
376 bool needs_input2_grad) {
      // Record which inputs are expected to need gradients so that backward()
      // can verify needs_input_grad().
378 ctx->saved_data["needs_input1_grad"] = needs_input1_grad;
379 ctx->saved_data["needs_input2_grad"] = needs_input2_grad;
380
381 ctx->saved_data["mul"] = mul;
382 ctx->save_for_backward({var1, var2});
383 return var1 + mul * var2 + var1 * var2;
384 }
385
386 static variable_list backward(
387 AutogradContext* ctx,
388 variable_list grad_output) {
      // Check that the `needs_input_grad` method works correctly.
      // This can only be verified from within the backward function.
391 auto needs_input1_grad = ctx->saved_data["needs_input1_grad"].toBool();
392 auto needs_input2_grad = ctx->saved_data["needs_input2_grad"].toBool();
393 IndexRange var1_idx = {0, 1};
394 IndexRange var2_idx = {1, 2};
395 EXPECT_EQ(ctx->needs_input_grad(0), needs_input1_grad);
396 EXPECT_EQ(ctx->needs_input_grad(1), needs_input2_grad);
397 EXPECT_EQ(ctx->needs_input_grad({var1_idx}), needs_input1_grad);
398 EXPECT_EQ(ctx->needs_input_grad({var2_idx}), needs_input2_grad);
399 EXPECT_EQ(
400 ctx->needs_input_grad({var1_idx, var2_idx}),
401 needs_input1_grad || needs_input2_grad);
402
403 // calculate gradients
404 int mul = ctx->saved_data["mul"].toInt();
405 auto saved = ctx->get_saved_variables();
406 auto var1 = saved[0];
407 auto var2 = saved[1];
408
409 Variable grad_var1, grad_var2;
410 if (ctx->needs_input_grad(0)) {
411 grad_var1 = grad_output[0] + grad_output[0] * var2;
412 }
413 if (ctx->needs_input_grad(1)) {
414 grad_var2 = grad_output[0] * mul + grad_output[0] * var1;
415 }
416 variable_list output = {
417 grad_var1,
418 grad_var2,
419 Variable(),
420 Variable(),
421 Variable(),
422 };
423 return output;
424 }
425 };
426
427 Variable x = torch::randn({5, 5}, torch::requires_grad());
428 Variable y = torch::randn({5, 5}, torch::requires_grad());
429 auto go = torch::ones_like(x);
430 Variable out;
431
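  // Because torch::autograd::grad() is called with an explicit subset of
  // inputs, the engine trims the graph task to just those edges, and
  // ctx->needs_input_grad() inside backward reflects exactly what was requested.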
432 // grad_x
433 out = MyFunction::apply(
434 x,
435 y,
436 2,
437 /* needs_input1_grad= */ true,
438 /* needs_input2_grad= */ false);
439 auto grad_x = torch::autograd::grad({out}, {x}, {go})[0];
440 ASSERT_VARIABLE_EQ(grad_x, y + torch::ones({5, 5}));
441
442 // grad_y
443 out = MyFunction::apply(
444 x,
445 y,
446 2,
447 /* needs_input1_grad= */ false,
448 /* needs_input2_grad= */ true);
449 auto grad_y = torch::autograd::grad({out}, {y}, {go})[0];
450 ASSERT_VARIABLE_EQ(grad_y, x + torch::ones({5, 5}) * 2);
451
452 // grad_x and grad_y
453 out = MyFunction::apply(
454 x,
455 y,
456 2,
457 /* needs_input1_grad= */ true,
458 /* needs_input2_grad= */ true);
459 auto grads = torch::autograd::grad({out}, {x, y}, {go});
460 ASSERT_VARIABLE_EQ(grads[0], y + torch::ones({5, 5}));
461 ASSERT_VARIABLE_EQ(grads[1], x + torch::ones({5, 5}) * 2);
462 }
463
TEST(CustomAutogradTest, FunctionReturnsInput) {
465 struct MyFunction : public Function<MyFunction> {
466 static Variable forward(AutogradContext* ctx, Variable var1) {
467 return var1;
468 }
469
470 static variable_list backward(
471 AutogradContext* ctx,
472 variable_list grad_output) {
473 return {grad_output[0] * 2};
474 }
475 };
476
477 Variable x(torch::ones(1, torch::requires_grad()));
478 MyFunction::apply(x).backward(torch::ones(1), true, true);
479 ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2.));
480 }
481
TEST(CustomAutogradTest, FunctionReturnsUndefined) {
483 struct MyFunction : public Function<MyFunction> {
484 static Variable forward(AutogradContext* ctx, Variable var) {
485 return var * 2;
486 }
487
488 static variable_list backward(
489 AutogradContext* ctx,
490 variable_list grad_output) {
491 at::Tensor undefined_tensor;
492 return {undefined_tensor};
493 }
494 };
495
496 auto x = torch::ones(1, torch::requires_grad());
497
498 MyFunction::apply(x).backward();
499 ASSERT_FALSE(x.grad().defined());
500
501 MyFunction::apply(x.pow(2)).backward();
502 ASSERT_FALSE(x.grad().defined());
503
504 MyFunction::apply(x).sum().backward();
505 ASSERT_FALSE(x.grad().defined());
506
507 ASSERT_FALSE(torch::autograd::grad(
508 {MyFunction::apply(x)}, {x}, {}, false, false, true)[0]
509 .defined());
510 }
511
TEST(CustomAutogradTest, MaterializeGrads) {
513 struct MyFunction : public Function<MyFunction> {
514 static Variable forward(AutogradContext* ctx, Variable var) {
515 return var;
516 }
517
518 static variable_list backward(
519 AutogradContext* ctx,
520 variable_list grad_output) {
521 EXPECT_VARIABLE_EQ(grad_output[0], torch::zeros(1));
522 return grad_output;
523 }
524 };
525
526 auto x = torch::ones(1, torch::requires_grad());
527 UndefinedGrad().apply({MyFunction::apply(x)})[0].backward();
528 }
529
TEST(CustomAutogradTest, DontMaterializeGrads) {
531 struct MyFunction : public Function<MyFunction> {
532 static Variable forward(AutogradContext* ctx, Variable var) {
533 ctx->set_materialize_grads(false);
534 return var;
535 }
536
537 static variable_list backward(
538 AutogradContext* ctx,
539 variable_list grad_output) {
540 EXPECT_FALSE(grad_output[0].defined());
541 return grad_output;
542 }
543 };
544
545 auto x = torch::ones(1, torch::requires_grad());
546 UndefinedGrad().apply({MyFunction::apply(x)})[0].backward();
547 }
548
TEST(CustomAutogradTest, NoGradCustomFunction) {
550 // Custom Function should respect grad mode
551 struct MyOp : public Function<MyOp> {
552 static Variable forward(AutogradContext* ctx, Variable x) {
553 return x + 1;
554 }
555
556 static variable_list backward(AutogradContext* ctx, variable_list dy) {
557 return dy;
558 }
559 };
560
561 auto x = torch::ones({5, 5}, torch::requires_grad());
562 {
563 at::NoGradGuard no_grad;
564 auto y = MyOp::apply(x);
565 ASSERT_FALSE(y.requires_grad());
566 }
567 }
568
TEST(CustomAutogradTest, MarkDirty) {
570 struct MyFunction : public Function<MyFunction> {
571 static Variable forward(AutogradContext* ctx, Variable v) {
572 // Change the value inplace
573 auto v_data = v.data_ptr<float>();
574 v_data[0] = 2;
575 ctx->mark_dirty({v});
576 return v;
577 }
578
579 static variable_list backward(
580 AutogradContext* ctx,
581 variable_list grad_output) {
582 return {(grad_output[0] * 2.0)};
583 }
584 };
585
  // Clone here because modifying leaves in-place is not allowed.
587 auto x = torch::randn({5, 5}, torch::requires_grad()).clone();
588 auto version_before = x._version();
589 auto out = MyFunction::apply(x);
590 auto version_after = x._version();
591 ASSERT_TRUE(version_after >= (version_before + 1));
592 out.sum().backward();
593 }
594
TEST(CustomAutogradTest, MarkNonDifferentiable) {
596 struct MyFunction : public Function<MyFunction> {
597 static Variable forward(AutogradContext* ctx, Variable v) {
598 Variable output = v > 0;
599 ctx->mark_non_differentiable({output});
600 return output;
601 }
602
603 static variable_list backward(
604 AutogradContext* ctx,
605 variable_list grad_output) {
606 return {(grad_output[0] * 0.0)};
607 }
608 };
609
610 auto x = torch::randn({5, 5}, torch::requires_grad());
611 auto mask = MyFunction::apply(x);
612 ASSERT_FALSE(mask.requires_grad());
613 auto y = x.masked_fill(mask, 0);
614 y.sum().backward();
615 }
616
TEST(CustomAutogradTest, MarkNonDifferentiableMixed) {
618 struct MyFunction : public Function<MyFunction> {
619 static variable_list forward(AutogradContext* ctx, Variable input) {
620 Variable a = input + 1;
621 Variable b = input + 2;
622 ctx->mark_non_differentiable({a});
623 return {a, b};
624 }
625
626 static variable_list backward(
627 AutogradContext* ctx,
628 variable_list grad_output) {
629 const Variable &grad_a = grad_output[0], &grad_b = grad_output[1];
630 EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5, 5}));
631 EXPECT_VARIABLE_EQ(grad_b, torch::ones({5, 5}));
632 return {grad_b};
633 }
634 };
635
636 auto x = torch::randn({5, 5}, torch::requires_grad());
637 auto out = MyFunction::apply(x);
638
639 ASSERT_FALSE(out[0].requires_grad());
640 ASSERT_TRUE(out[1].requires_grad());
641 out[1].sum().backward();
642 ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5, 5}));
643 }
644
TEST(CustomAutogradTest, MarkNonDifferentiableNone) {
646 struct MyFunction : public Function<MyFunction> {
647 static Variable forward(AutogradContext* ctx, Variable input) {
648 auto output = input.clone();
649 ctx->mark_non_differentiable({output});
650 return output;
651 }
652
653 static variable_list backward(
654 AutogradContext* ctx,
655 variable_list grad_outputs) {
656 return {};
657 }
658 };
659
660 auto x = torch::randn({5, 5}, torch::requires_grad());
661 auto r = MyFunction::apply(x * x);
662 (r * x).sum().backward();
663 }
664
TEST(CustomAutogradTest, ReturnLeafInplace) {
666 struct Inplace : public Function<Inplace> {
667 static variable_list forward(AutogradContext* ctx, Variable a, Variable b) {
668 ctx->mark_dirty({a});
669 return {a.add_(b), b + 2};
670 }
671
672 static variable_list backward(
673 AutogradContext* ctx,
674 variable_list grad_output) {
675 return {grad_output[0], grad_output[0] + grad_output[1]};
676 }
677 };
678
679 Variable x = torch::randn({5, 5});
680 Variable y = torch::randn({5, 5}, torch::requires_grad());
681
682 auto out = Inplace::apply(x, y);
683 auto& q = out[0];
684 ASSERT_TRUE(torch::equal(q, x));
685 ASSERT_TRUE(q.requires_grad());
686 q.sum().backward();
687 ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5, 5}));
688 }
689
TEST(CustomAutogradTest, ReturnDuplicateInplace) {
691 struct DoubleInplace : public Function<DoubleInplace> {
692 static variable_list forward(AutogradContext* ctx, Variable x) {
693 x.mul_(2);
694 ctx->mark_dirty({x});
695 return {x, x};
696 }
697
698 static variable_list backward(
699 AutogradContext* ctsx,
700 variable_list grad_outputs) {
701 return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
702 }
703 };
704
705 auto x = torch::randn({5, 5}, torch::requires_grad());
706
707 ASSERT_THROWS_WITH(
708 DoubleInplace::apply(x), "leaf Variable that requires grad");
709 // TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one
710 // output");
711
712 auto out = DoubleInplace::apply(x.clone());
713 ASSERT_TRUE(torch::equal(out[0], out[1]));
714 }
715
TEST(CustomAutogradTest, ReturnDuplicate) {
717 struct DoubleDuplicate : public Function<DoubleDuplicate> {
718 static variable_list forward(AutogradContext* ctx, Variable x) {
719 auto output = x * 2;
720 return {output, output};
721 }
722
723 static variable_list backward(
724 AutogradContext* ctx,
725 variable_list grad_outputs) {
726 return {grad_outputs[0] * 2 + grad_outputs[1] * 2};
727 }
728 };
729
730 auto x = torch::randn({5, 5}, torch::requires_grad());
731 auto out = DoubleDuplicate::apply(x);
732 ASSERT_TRUE(torch::equal(out[0], out[1]));
733 }
734
TEST(CustomAutogradTest, SaveEmptyForBackward) {
736 struct MyFunction : public Function<MyFunction> {
737 static Variable forward(AutogradContext* ctx, Variable input) {
738 ctx->save_for_backward({Variable(), input, Variable()});
739 return input * input;
740 }
741
742 static variable_list backward(
743 AutogradContext* ctx,
744 variable_list grad_output) {
745 auto saved = ctx->get_saved_variables();
746 EXPECT_FALSE(saved[0].defined());
747 EXPECT_FALSE(saved[2].defined());
748 return {saved[1] * 2 * grad_output[0]};
749 }
750 };
751
752 Variable x = torch::randn({5, 5}, torch::requires_grad());
753 auto y = MyFunction::apply(x);
754 y.sum().backward();
755 ASSERT_VARIABLE_EQ(x.grad(), 2 * x);
756 }
757
TEST(CustomAutogradTest, InvalidGradients) {
759 struct MyFunction : public Function<MyFunction> {
760 static Variable forward(AutogradContext* ctx, Variable x) {
761 return x * 2;
762 }
763
764 static variable_list backward(
765 AutogradContext* ctsx,
766 variable_list grad_outputs) {
767 return {
768 torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))};
769 }
770 };
771
772 auto input1 =
773 torch::randn({5, 5}, torch::dtype(torch::kFloat).requires_grad(true));
774 ASSERT_THROWS_WITH(
775 MyFunction::apply(input1).sum().backward(), "expected shape");
776 auto input2 =
777 torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true));
778 }
779
TEST(CustomAutogradTest, NoGradInput) {
781 struct MyFunction : public Function<MyFunction> {
782 static Variable forward(AutogradContext*, Variable x) {
783 return x;
784 }
785
786 static variable_list backward(
787 AutogradContext*,
788 variable_list grad_outputs) {
789 return grad_outputs;
790 }
791 };
792
793 Variable x = torch::randn({5, 5}, torch::requires_grad());
794 Variable y;
795 {
796 at::NoGradGuard no_grad;
797 y = MyFunction::apply(x);
798 }
799
800 ASSERT_TRUE(x.requires_grad());
801 ASSERT_FALSE(y.grad_fn());
802 }
803
TEST(CustomAutogradTest, TooManyGrads) {
805 struct MyFunction : public Function<MyFunction> {
806 static Variable forward(AutogradContext*, Variable input) {
807 return input;
808 }
809
810 static variable_list backward(AutogradContext*, variable_list grad_output) {
811 grad_output.insert(grad_output.end(), {Variable(), Variable()});
812 return grad_output;
813 }
814 };
815 }
816
TEST(CustomAutogradTest, DepNoGrad) {
818 struct F1 : public Function<F1> {
819 static variable_list forward(AutogradContext* ctx, Variable input) {
820 auto out = torch::randn(input.sizes());
821 ctx->mark_non_differentiable({out});
822 return {input, out};
823 }
824
825 static variable_list backward(
826 AutogradContext* ctx,
827 variable_list grad_output) {
828 return {grad_output[0]};
829 }
830 };
831
832 struct F2 : public Function<F2> {
833 static Variable forward(AutogradContext*, Variable input, Variable ignore) {
834 return input;
835 }
836
837 static variable_list backward(AutogradContext*, variable_list grad_output) {
838 return {grad_output[0], Variable()};
839 }
840 };
841
842 auto x = torch::randn(5, torch::requires_grad());
843 auto out = F1::apply(x);
844 Variable &a = out[0], &b = out[1];
845 b = b + 1; // Separate F1 and F2 by another operation
846 ASSERT_TRUE(a.requires_grad());
847 ASSERT_FALSE(b.requires_grad());
848
849 auto c = F2::apply(a, b);
850 c.backward(torch::ones(c.sizes()), false, false);
851 ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes()));
852 }
853
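// Reentrant autograd: Reenter::forward builds a differentiable graph with grad
// mode re-enabled, and Reenter::backward calls backward() on it from inside the
// outer backward; the engine must run that nested graph task to completion.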
TEST(CustomAutogradTest, Reentrant) {
855 static Variable y_data = torch::randn({2, 2});
856 struct Reenter : public Function<Reenter> {
857 static Variable forward(AutogradContext* ctx, Variable input) {
858 Variable output;
859 {
860 at::AutoGradMode enable_grad(true);
861 auto x = make_variable(input.tensor_data(), true);
862 auto y = make_variable(y_data.tensor_data(), true);
863 output = x * y;
864
865 ctx->saved_data["x"] = x;
866 ctx->saved_data["y"] = y;
867 ctx->saved_data["output_var"] = output;
868 }
869 return output.detach();
870 }
871
872 static variable_list backward(
873 AutogradContext* ctx,
874 variable_list grad_output) {
875 {
876 at::AutoGradMode enable_grad(true);
877 auto out = ctx->saved_data["output_var"].toTensor();
878 out.sum().backward();
879 }
880 return {ctx->saved_data["x"].toTensor().grad() * grad_output[0]};
881 }
882 };
883
884 auto x = torch::randn({2, 2}, torch::requires_grad());
885 auto out = Reenter::apply(x);
886 out.sum().backward();
887 ASSERT_VARIABLE_EQ(x.grad(), y_data);
888 }
889
890 // NOTE: If this fails for apparently unrelated reasons in TSAN be aware of
891 // the TSAN limit on mutex: https://github.com/google/sanitizers/issues/950
TEST(CustomAutogradTest, DeepReentrant) {
893 struct DeepReenter : public Function<DeepReenter> {
894 static Variable forward(AutogradContext* ctx, Variable x) {
895 {
896 at::AutoGradMode enable_grad(true);
897 ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
898 }
899 return ctx->saved_data["x"].toTensor().detach();
900 }
901
902 static variable_list backward(
903 AutogradContext* ctx,
904 variable_list grad_output) {
905 if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
906 return grad_output;
907 }
908 {
909 at::AutoGradMode enable_grad(true);
910 apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
911 return grad_output;
912 }
913 }
914 };
915
916 // This should not stack overflow
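  // 8193 levels of nested reentrant backward calls would overflow the C++
  // stack if the engine recursed directly; deep reentrant tasks are expected to
  // be queued (and offloaded to fresh worker threads past a depth limit) instead.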
917 auto v =
918 torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true));
919 DeepReenter::apply(v).sum().backward();
920 }
921
TEST(CustomAutogradTest, ReentrantPriority) {
923 static std::vector<int> order;
924
925 struct MyFunction : public Function<MyFunction> {
926 static Variable forward(AutogradContext*, Variable x) {
927 return x;
928 }
929
930 static variable_list backward(AutogradContext*, variable_list grad) {
931 order.push_back(0);
932 return grad;
933 }
934 };
935
936 struct Reenter : public Function<Reenter> {
937 static Variable forward(AutogradContext* ctx, Variable x) {
938 {
939 at::AutoGradMode enable_grad(true);
940 ctx->saved_data["x"] = make_variable(x.tensor_data(), true) - 1;
941 }
942 return ctx->saved_data["x"].toTensor().detach();
943 }
944
945 static variable_list backward(
946 AutogradContext* ctx,
947 variable_list grad_output) {
948 order.push_back(1);
949 if (!at::native::is_nonzero(ctx->saved_data["x"].toTensor())) {
950 return grad_output;
951 }
952 {
953 at::AutoGradMode enable_grad(true);
954 apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
955 return grad_output;
956 }
957 }
958 };
959
960 auto a = MyFunction::apply(
961 torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true)));
962 auto b = Reenter::apply(
963 torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true)));
964 auto v = a * b;
965 v.backward();
966
967 // All the reentrant tasks should be prioritized over the MyFunction backward
968 // task.
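  // Reenter::backward fires for saved values 8 down to 0, pushing nine 1s;
  // MyFunction's backward task runs last and pushes the single trailing 0.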
969 ASSERT_EQ(order.size(), 10);
970 ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9);
971 ASSERT_EQ(order.back(), 0);
  // Clear the static variable in case the test gets executed in a loop.
973 order.clear();
974 }
975
TEST(CustomAutogradTest, Hooks) {
977 Variable x = torch::ones({5, 5}, torch::requires_grad());
978 Variable y = torch::ones({5, 5}) * 4;
979 y.set_requires_grad(true);
980
981 int counter = 0;
982
983 std::function<void(int, Variable)> bw_hook(
984 [&counter](int inc, Variable grad) { counter += inc; });
985
986 Variable z = x * x + x * 2 + x * y + y;
987 x.register_hook([&bw_hook](Variable grad) { bw_hook(0, grad); });
988 auto hook_1 =
989 z.register_hook([&bw_hook](Variable grad) { bw_hook(1, grad); });
990 z.backward(torch::ones({5, 5}), true, true);
991 ASSERT_EQ(counter, 1);
992
993 auto hook_2 =
994 z.register_hook([&bw_hook](Variable grad) { bw_hook(2, grad); });
995 z.backward(torch::ones({5, 5}), true, true);
996 ASSERT_EQ(counter, 4);
997
998 z.remove_hook(hook_2);
999 z.backward(torch::ones({5, 5}), true, true);
1000 ASSERT_EQ(counter, 5);
1001
1002 std::function<Variable(Variable)> bw_hook_modify(
1003 [](Variable grad) { return grad.mul(2); });
1004
1005 z.remove_hook(hook_1);
1006 z.register_hook(bw_hook_modify);
1007 y.grad().zero_();
1008 z.backward(torch::ones({5, 5}), true, false);
1009 ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 2);
1010
1011 y.register_hook(bw_hook_modify);
1012 y.grad().zero_();
1013 z.backward(torch::ones({5, 5}), false, false);
1014 ASSERT_VARIABLE_EQ(y.grad(), (x + 1) * 4);
1015
1016 ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
1017 }
1018
TEST(CustomAutogradTest, HooksInplace) {
1020 auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
1021
1022 int hook1_count = 0;
1023 auto hook1 = ([&hook1_count](Variable grad) {
1024 hook1_count++;
1025 ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
1026 });
1027
1028 int hook2_count = 0;
1029 auto hook2 = ([&hook2_count](Variable grad) {
1030 hook2_count++;
1031 ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
1032 });
1033
1034 a.register_hook(hook1);
1035 a.mul_(2);
1036 a.register_hook(hook2);
1037
1038 auto out = (a + 1).sum();
1039 out.backward();
1040
1041 ASSERT_EQ(hook1_count, 1);
1042 ASSERT_EQ(hook2_count, 1);
1043 }
1044
TEST(CustomAutogradTest, HooksInplaceWithRetainsGrad) {
1046 auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
1047
1048 int hook1_count = 0;
1049 auto hook1 = ([&hook1_count](Variable grad) {
1050 hook1_count++;
1051 ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
1052 });
1053
1054 int hook2_count = 0;
1055 auto hook2 = ([&hook2_count](Variable grad) {
1056 hook2_count++;
1057 ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 2);
1058 });
1059
1060 int hook3_count = 0;
1061 auto hook3 = ([&hook3_count](Variable grad) {
1062 hook3_count++;
1063 ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
1064 });
1065
1066 a.register_hook(hook1);
1067 a.retain_grad();
1068 a.register_hook(hook2);
1069
1070 a.mul_(2);
1071 a.register_hook(hook3);
1072
1073 auto out = (a + 1).sum();
1074 out.backward();
1075
1076 ASSERT_EQ(hook1_count, 1);
1077 ASSERT_EQ(hook2_count, 1);
1078 ASSERT_EQ(hook3_count, 1);
1079
1080 ASSERT_TRUE(a.retains_grad());
1081 ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
1082 }
1083
TEST(CustomAutogradTest, HooksInplaceTwiceWithRetainsGrad) {
1085 auto a = torch::ones({5, 5}, torch::requires_grad()).clone();
1086
1087 int hook1_count = 0;
1088 auto hook1 = ([&hook1_count](Variable grad) {
1089 hook1_count++;
1090 ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
1091 });
1092
1093 int hook2_count = 0;
1094 auto hook2 = ([&hook2_count](Variable grad) {
1095 hook2_count++;
1096 ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}) * 4);
1097 });
1098
1099 int hook3_count = 0;
1100 auto hook3 = ([&hook3_count](Variable grad) {
1101 hook3_count++;
1102 ASSERT_VARIABLE_EQ(grad, torch::ones({5, 5}));
1103 });
1104
1105 a.register_hook(hook1);
1106 a.retain_grad();
1107 a.register_hook(hook2);
1108
1109 a.mul_(2);
1110 a.mul_(2);
1111 a.register_hook(hook3);
1112
1113 auto out = (a + 1).sum();
1114 out.backward();
1115
1116 ASSERT_EQ(hook1_count, 1);
1117 ASSERT_EQ(hook2_count, 1);
1118 ASSERT_EQ(hook3_count, 1);
1119
1120 ASSERT_TRUE(a.retains_grad());
1121 ASSERT_VARIABLE_EQ(a.grad(), torch::ones({5, 5}));
1122 }
1123
TEST(CustomAutogradTest, HookNone) {
1125 struct NoneGradientFunction : public Function<NoneGradientFunction> {
1126 static variable_list forward(AutogradContext* ctx, Variable x, Variable y) {
1127 return {x, y};
1128 }
1129
1130 static variable_list backward(AutogradContext* ctx, variable_list grad) {
1131 return {grad[0], Variable()};
1132 }
1133 };
1134
1135 bool was_called = false;
1136
1137 auto hook = ([&was_called](Variable grad) {
1138 ASSERT_TRUE(grad.defined());
1139 was_called = true;
1140 });
1141
1142 auto x = torch::randn({5, 5}, torch::requires_grad());
1143 auto y = torch::randn({5, 5});
1144
1145 auto out = NoneGradientFunction::apply(x, y);
1146 Variable rx = x[0], ry = x[1];
1147
1148 rx.register_hook(hook);
1149 ry.register_hook(hook);
1150 (rx + ry).sum().backward();
1151 ASSERT_TRUE(was_called);
1152 }
1153
TEST(CustomAutogradTest, BackwardWithInputs) {
1155 Variable x = torch::randn({5, 5}, torch::requires_grad());
1156 Variable y = torch::randn({5, 5}, torch::requires_grad());
1157 Variable z = x * x + x * y + y * y;
1158 Variable x_grad_expected = 2 * x + y;
1159 Variable y_grad_expected = x + 2 * y;
1160
1161 z.backward(torch::ones({5, 5}), false, false, {x});
1162
1163 ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
1164 ASSERT_FALSE(y.grad().defined());
1165 }
1166
TEST(CustomAutogradTest, BackwardWithEmptyInputs) {
1168 Variable x = torch::randn({5, 5}, torch::requires_grad());
1169 Variable y = torch::randn({5, 5}, torch::requires_grad());
1170 Variable z = x * x + x * y + y * y;
1171 Variable x_grad_expected = 2 * x + y;
1172 Variable y_grad_expected = x + 2 * y;
1173 ASSERT_THROWS_WITH(
1174 z.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{}),
1175 "cannot be empty");
1176 }
1177
TEST(CustomAutogradTest, BackwardWithNonLeafInputs) {
1179 Variable x = torch::randn({5, 5}, torch::requires_grad());
1180 Variable y = torch::randn({5, 5}, torch::requires_grad());
1181 Variable z = x * x;
1182 Variable w = y * z + x * y + y * y;
1183
1184 Variable x_grad_expected = 2 * x * y + y;
1185 Variable z_grad_expected = y;
1186
1187 w.backward(torch::ones({5, 5}), false, false, std::vector<Variable>{x, z});
1188
1189 ASSERT_VARIABLE_EQ(x.grad(), x_grad_expected);
1190 ASSERT_VARIABLE_EQ(z.grad(), z_grad_expected);
1191 ASSERT_FALSE(y.grad().defined());
1192 }
1193
TEST(CustomAutogradTest, BackwardWithCreateGraphWarns) {
1195 c10::WarningUtils::WarnAlways guard(true);
1196
1197 torch::Tensor x = torch::randn({5, 5}).set_requires_grad(true);
1198 auto z = x * x;
1199 {
1200 WarningCapture warnings;
1201 z.backward(torch::ones({5, 5}), std::nullopt, true);
1202 ASSERT_TRUE(
1203 warnings.str().find("Using backward() with create_graph=True") !=
1204 std::string::npos);
1205 }
1206
1207 {
1208 WarningCapture warnings;
1209 torch::autograd::backward({z}, {torch::ones({5, 5})}, std::nullopt, true);
1210 ASSERT_TRUE(
1211 warnings.str().find("Using backward() with create_graph=True") !=
1212 std::string::npos);
1213 }
1214 }
1215
1216 /**
1217 * Tests for AutogradNotImplementedFallback
 * - Check that the NotImplemented node is created when inputs require grad,
 *   and that it is not created when no inputs require grad
1220 * - check_inplace logic
1221 * - view ops
1222 * - TODO: Tests for debug-only checks? Don't need for now because CI doesn't
1223 * test non-NDEBUG builds.
1224 * - tensorlist input and output
1225 * - multiple outputs / non-tensor output
1226 * - rebase_history vs set_history
1227 */
1228 namespace {
1229
torch::Tensor inplace_op(
1231 const torch::Tensor& self,
1232 const torch::Tensor& other) {
1233 return self.add_(other);
1234 }
1235
std::tuple<torch::Tensor, torch::Tensor> two_arg_inplace_op(
1237 const torch::Tensor& self,
1238 const torch::Tensor& other) {
1239 other.add_(self);
1240 self.add_(other);
1241 return std::tuple<torch::Tensor, torch::Tensor>(self, other);
1242 }
1243
std::tuple<torch::Tensor, torch::Tensor> two_pairs_of_view_op(
1245 const torch::Tensor& self,
1246 const torch::Tensor& other) {
  // This is not allowed. We test below that calling this through the boxed
  // kernel raises an error.
1249 return std::tuple<torch::Tensor, torch::Tensor>(self, other);
1250 }
1251
std::tuple<torch::Tensor, torch::Tensor> non_first_view_op(
1253 const torch::Tensor& self,
1254 const torch::Tensor& other) {
  // This is not allowed. We test below that calling this through the boxed
  // kernel raises an error.
1257 return std::tuple<torch::Tensor, torch::Tensor>(self.clone(), other);
1258 }
1259
int64_t ret_single_non_tensor(
1261 const torch::Tensor& self,
1262 const torch::Tensor& other) {
1263 return 12;
1264 }
1265
torch::Tensor opt_op(
1267 const torch::Tensor& self,
1268 const std::optional<at::Tensor>& other) {
1269 if (other.has_value()) {
1270 return self + other.value();
1271 } else {
1272 return self.clone();
1273 }
1274 }
1275
torch::Tensor my_custom_op(
1277 const torch::Tensor& self,
1278 const torch::Tensor& other) {
1279 return self + other;
1280 }
1281
std::tuple<torch::Tensor, torch::Tensor, int64_t> ret_tuple_non_tensor(
1283 const torch::Tensor& self,
1284 const torch::Tensor& other) {
1285 auto a = self - other;
1286 auto b = self + other;
1287 return std::tuple<torch::Tensor, torch::Tensor, int64_t>(a, b, 12);
1288 }
1289
torch::Tensor view_op(const torch::Tensor& self) {
1291 return self.alias();
1292 }
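// alias() returns a view sharing storage with `self`; combined with the
// "Tensor(a) -> Tensor(a)" alias annotation used when this op is registered
// below, the fallback records a view relationship between input and output.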
1293
torch::Tensor view_op_with_extra_arg(
1295 const torch::Tensor& self,
1296 const torch::Tensor& other) {
1297 return self.alias();
1298 }
1299
std::vector<torch::Tensor> ret_tensor_vector_view(
1301 const torch::Tensor& self,
1302 const torch::Tensor& other) {
1303 return {self.alias(), self.alias()};
1304 }
1305
std::vector<at::Tensor> ret_tensor_vector(
1307 const torch::Tensor& self,
1308 const torch::Tensor& other) {
1309 std::vector<at::Tensor> out;
1310 out.push_back(self + other);
1311 out.push_back(self - other);
1312 return out;
1313 }
1314
torch::Tensor tensorlist_op(const torch::Tensor& self, at::TensorList other) {
1316 const auto& res = self.clone();
1317 for (const auto& t : other) {
1318 res.add_(t);
1319 }
1320 return res;
1321 }
1322
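// REGISTER_TEST_OP registers a schema under the _test namespace with a real
// CPU kernel, routes the Autograd key to autogradNotImplementedFallback(), and
// routes ADInplaceOrView to autogradNotImplementedInplaceOrViewFallback(),
// mimicking an operator that has no derivative formula of its own.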
1323 #define REGISTER_TEST_OP(name, schema, fn) \
1324 auto m = MAKE_TORCH_LIBRARY(_test); \
1325 m.def(schema); \
1326 auto m_autograd = MAKE_TORCH_LIBRARY_IMPL(_test, Autograd); \
1327 auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU); \
1328 auto m_inplaceorview = MAKE_TORCH_LIBRARY_IMPL(_test, ADInplaceOrView); \
1329 m_cpu.impl(name, c10::DispatchKey::CPU, TORCH_FN(fn)); \
1330 m_autograd.impl( \
1331 name, c10::DispatchKey::Autograd, autogradNotImplementedFallback()); \
1332 m_inplaceorview.impl( \
1333 name, \
1334 c10::DispatchKey::ADInplaceOrView, \
1335 autogradNotImplementedInplaceOrViewFallback());
1336
1337 template <typename F>
void assertBasicChecks(F op) {
1339 auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1340 auto b = torch::tensor({1.}, {torch::kFloat32});
1341 auto c = torch::tensor({1.}, {torch::kFloat32});
1342
  // If any input requires grad, the fallback attaches a NotImplemented node,
  // so backward() throws.
1344 auto out1 = op(a, b);
1345 ASSERT_THROWS_WITH(out1.backward(), "is not implemented");
1346
  // If no input requires grad, the output should not get a grad_fn at all.
1348 auto out2 = op(b, c);
1349 ASSERT_THROWS_WITH(
1350 out2.backward(),
1351 "element 0 of tensors does not require grad and does not have a grad_fn");
1352
1353 // TODO: Forward AD Tests?
1354 }
1355
1356 } // namespace
1357
TEST(TestAutogradNotImplementedFallback, RetSingleNonTensor) {
1359 REGISTER_TEST_OP(
1360 "ret_single_non_tensor",
1361 "_test::ret_single_non_tensor(Tensor self, Tensor other) -> int",
1362 ret_single_non_tensor);
1363 auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1364 "_test::ret_single_non_tensor", "");
1365 auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1366 return callOpUnboxed<int64_t, const torch::Tensor&, const torch::Tensor&>(
1367 opHandle, _1, _2);
1368 };
1369
1370 auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1371 auto b = torch::tensor({1.}, {torch::kFloat32});
1372
1373 ASSERT_EQ(op(a, b), ret_single_non_tensor(a, b));
1374 }
1375
TEST(TestAutogradNotImplementedFallback, InplaceOp) {
1377 REGISTER_TEST_OP(
1378 "inplace_op",
1379 "_test::inplace_op(Tensor(a!) self, Tensor other) -> Tensor(a!)",
1380 inplace_op);
1381 auto opHandle =
1382 c10::Dispatcher::singleton().findSchemaOrThrow("_test::inplace_op", "");
1383 auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1384 return callOpUnboxed<
1385 torch::Tensor,
1386 const torch::Tensor&,
1387 const torch::Tensor&>(opHandle, _1, _2);
1388 };
1389
1390 auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1391 auto b = torch::tensor({1.}, {torch::kFloat32});
1392
1393 // Check in-place
1394 ASSERT_THROWS_WITH(
1395 op(a, b),
1396 "a leaf Variable that requires grad is being used in an in-place operation");
1397 op(b, a);
1398 a = a.clone();
1399 b = b.clone();
1400 auto c = op(a, b);
1401 ASSERT_TRUE(torch::allclose(c, inplace_op(a, b)));
1402
1403 // Test in-place on view
1404 auto base =
1405 torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
1406 auto view = base.view(-1);
1407 auto t = torch::tensor({1.}, {torch::kFloat32});
1408
1409 torch::Tensor v_nograd;
1410 {
1411 c10::NoGradGuard guard;
1412 v_nograd = base.view(-1);
1413 op(v_nograd, t);
1414 }
1415
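  // Outside of the guard, in-place modification of a view that was created in
  // no_grad mode (on a base that requires grad) must be rejected.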
1416 ASSERT_THROWS_WITH(op(v_nograd, t), "A view was created in no_grad mode");
1417 ASSERT_EQ(op(view, t).unsafeGetTensorImpl(), view.unsafeGetTensorImpl());
1418 ASSERT_THAT(
1419 op(view, t).grad_fn()->name(), ::testing::HasSubstr("AsStridedBackward"));
1420 }
1421
TEST(TestAutogradNotImplementedFallback, DoubleInplaceOp) {
1423 REGISTER_TEST_OP(
1424 "two_arg_inplace_op",
1425 "_test::two_arg_inplace_op(Tensor(a!) self, Tensor(b!) other) -> (Tensor(a!), Tensor(b!))",
1426 two_arg_inplace_op);
1427 auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1428 "_test::two_arg_inplace_op", "");
1429 auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1430 return callOpUnboxed<
1431 std::tuple<torch::Tensor, torch::Tensor>,
1432 const torch::Tensor&,
1433 const torch::Tensor&>(opHandle, _1, _2);
1434 };
1435 auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1436 auto b = torch::tensor({1.}, {torch::kFloat32});
1437
1438 // Both are modified in-place!
1439 ASSERT_THROWS_WITH(
1440 op(a, b),
1441 "a leaf Variable that requires grad is being used in an in-place operation");
1442 ASSERT_THROWS_WITH(
1443 op(b, a),
1444 "a leaf Variable that requires grad is being used in an in-place operation");
1445
1446 auto c =
1447 torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
1448 auto d =
1449 torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
1450
1451 auto saved_version_c = c._version();
1452 auto saved_version_d = d._version();
1453 op(c, d);
1454 ASSERT_NE(c._version(), saved_version_c);
1455 ASSERT_NE(d._version(), saved_version_d);
1456 }
1457
TEST(TestAutogradNotImplementedFallback, OptOp) {
1459 REGISTER_TEST_OP(
1460 "opt_op", "_test::opt_op(Tensor self, Tensor? other) -> Tensor", opt_op);
1461 auto opHandle =
1462 c10::Dispatcher::singleton().findSchemaOrThrow("_test::opt_op", "");
1463 auto op = [&](const torch::Tensor& _1,
1464 const std::optional<torch::Tensor>& _2) {
1465 return callOpUnboxed<
1466 torch::Tensor,
1467 const torch::Tensor&,
1468 const std::optional<torch::Tensor>&>(opHandle, _1, _2);
1469 };
1470
1471 auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1472 auto b = torch::tensor({1.}, {torch::kFloat32});
1473
1474 ASSERT_TRUE(torch::allclose(op(a, b), opt_op(a, b)));
1475 ASSERT_TRUE(torch::allclose(op(a, {}), opt_op(a, {})));
1476 }
1477
TEST(TestAutogradNotImplementedFallback, OutOfPlaceAddition) {
1479 REGISTER_TEST_OP(
1480 "my_custom_op",
1481 "_test::my_custom_op(Tensor self, Tensor other) -> Tensor",
1482 my_custom_op);
1483 auto opHandle =
1484 c10::Dispatcher::singleton().findSchemaOrThrow("_test::my_custom_op", "");
1485 auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1486 return callOpUnboxed<
1487 torch::Tensor,
1488 const torch::Tensor&,
1489 const torch::Tensor&>(opHandle, _1, _2);
1490 };
1491
1492 assertBasicChecks(op);
1493 }
1494
TEST(TestAutogradNotImplementedFallback, RetTupleNonTensor) {
1496 REGISTER_TEST_OP(
1497 "ret_tuple_non_tensor",
1498 "_test::ret_tuple_non_tensor(Tensor self, Tensor other) -> (Tensor, Tensor, int)",
1499 ret_tuple_non_tensor);
1500 auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1501 "_test::ret_tuple_non_tensor", "");
1502 auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1503 auto out = callOpUnboxed<
1504 std::tuple<torch::Tensor, torch::Tensor, int64_t>,
1505 const torch::Tensor&,
1506 const torch::Tensor&>(opHandle, _1, _2);
1507 auto [out0, out1, out2] = std::move(out);
1508 return out0;
1509 };
1510
1511 assertBasicChecks(op);
1512 }
1513
TEST(TestAutogradNotImplementedFallback, ViewOp) {
1515 REGISTER_TEST_OP(
1516 "view_op", "_test::view_op(Tensor(a) self) -> Tensor(a)", view_op);
1517 auto opHandle =
1518 c10::Dispatcher::singleton().findSchemaOrThrow("_test::view_op", "");
1519 auto op = [&](const torch::Tensor& _1) {
1520 return callOpUnboxed<torch::Tensor, const torch::Tensor&>(opHandle, _1);
1521 };
1522 auto b = torch::tensor({1.}, {torch::kFloat32});
1523 auto v = op(b);
1524 ASSERT_TRUE(v.is_view());
1525 ASSERT_EQ(v._base().unsafeGetTensorImpl(), b.unsafeGetTensorImpl());
1526
1527 auto b1 =
1528 torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true).clone();
1529 auto v1 = op(b1);
1530 ASSERT_TRUE(v1.is_view());
1531 ASSERT_EQ(v1._base().unsafeGetTensorImpl(), b1.unsafeGetTensorImpl());
1532
1533 // Test inplace on view
1534 auto t = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1535
1536 // raise on rebase_history when it refreshes grad_fn
1537 ASSERT_THROWS_WITH(
1538 v1.add_(t), "which does not have a derivative implemented is forbidden");
1539 // base should not be aware of the views, so this is still okay
1540 b1.add_(t);
1541 ASSERT_THROWS_WITH(
1542 v1.grad_fn(),
1543 "which does not have a derivative implemented is forbidden");
1544 }
1545
TEST(TestAutogradNotImplementedFallback, ViewOpWithExtraArg) {
1547 REGISTER_TEST_OP(
1548 "view_op_with_extra_arg",
1549 "_test::view_op_with_extra_arg(Tensor(a) self, Tensor other) -> Tensor(a)",
1550 view_op_with_extra_arg);
1551 auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1552 "_test::view_op_with_extra_arg", "");
1553 auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1554 return callOpUnboxed<
1555 torch::Tensor,
1556 const torch::Tensor&,
1557 const torch::Tensor&>(opHandle, _1, _2);
1558 };
1559 assertBasicChecks(op);
1560 auto a = torch::tensor({1.}, {torch::kFloat32});
1561 auto b = torch::tensor({2.}, {torch::kFloat32});
1562 auto out1 = op(a, b);
1563 ASSERT_TRUE(out1.is_view());
1564 ASSERT_EQ(out1._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
1565 }
1566
TEST(TestAutogradNotImplementedFallback, RetTensorVectorView) {
1568 REGISTER_TEST_OP(
1569 "ret_tensor_vector_view",
1570 "_test::ret_tensor_vector_view(Tensor(a) self, Tensor other) -> Tensor[](a)",
1571 ret_tensor_vector_view);
1572 auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1573 "_test::ret_tensor_vector_view", "");
1574 auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1575 return callOpUnboxed<
1576 std::vector<at::Tensor>,
1577 const torch::Tensor&,
1578 const torch::Tensor&>(opHandle, _1, _2);
1579 };
1580 auto a = torch::tensor({1.}, {torch::kFloat32});
1581 auto b = torch::tensor({1.}, {torch::kFloat32});
1582 auto out = op(a, b);
1583 ASSERT_TRUE(out[0].is_view());
1584 ASSERT_EQ(out[0]._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
1585 ASSERT_TRUE(out[1].is_view());
1586 ASSERT_EQ(out[1]._base().unsafeGetTensorImpl(), a.unsafeGetTensorImpl());
1587 }
1588
TEST(TestAutogradNotImplementedFallback, DoubleViewOP) {
1590 REGISTER_TEST_OP(
1591 "two_pairs_of_view_op",
1592 "_test::two_pairs_of_view_op(Tensor(a) self, Tensor(b) other) -> (Tensor(a), Tensor(b))",
1593 two_pairs_of_view_op);
1594 auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1595 "_test::two_pairs_of_view_op", "");
1596 auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1597 return callOpUnboxed<
1598 std::tuple<torch::Tensor, torch::Tensor>,
1599 const torch::Tensor&,
1600 const torch::Tensor&>(opHandle, _1, _2);
1601 };
1602 auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1603 auto b = torch::tensor({1.}, {torch::kFloat32});
1604 ASSERT_THROWS_WITH(
1605 op(a, b),
1606 "Expected only a single output in the operator schema to have a non-write alias annotation");
1607 }
1608
TEST(TestAutogradNotImplementedFallback, NonFirstViewOP) {
1610 REGISTER_TEST_OP(
1611 "non_first_view_op",
1612 "_test::non_first_view_op(Tensor self, Tensor(b) other) -> (Tensor, Tensor(b))",
1613 non_first_view_op);
1614 auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1615 "_test::non_first_view_op", "");
1616 auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1617 return callOpUnboxed<
1618 std::tuple<torch::Tensor, torch::Tensor>,
1619 const torch::Tensor&,
1620 const torch::Tensor&>(opHandle, _1, _2);
1621 };
1622 auto a = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1623 auto b = torch::tensor({1.}, {torch::kFloat32});
1624 ASSERT_THROWS_WITH(
1625 op(a, b), "can only create view relationships between the first");
1626 }
1627
TEST(TestAutogradNotImplementedFallback, RetTensorVector) {
1629 REGISTER_TEST_OP(
1630 "ret_tensor_vector",
1631 "_test::ret_tensor_vector(Tensor self, Tensor other) -> Tensor[]",
1632 ret_tensor_vector);
1633 auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1634 "_test::ret_tensor_vector", "");
1635 auto op = [&](const torch::Tensor& _1, const torch::Tensor& _2) {
1636 return callOpUnboxed<
1637 std::vector<at::Tensor>,
1638 const torch::Tensor&,
1639 const torch::Tensor&>(opHandle, _1, _2)[0];
1640 };
1641 assertBasicChecks(op);
1642 }
1643
TEST(TestAutogradNotImplementedFallback, TensorlistOp) {
1645 REGISTER_TEST_OP(
1646 "tensorlist_op",
1647 "_test::tensorlist_op(Tensor self, Tensor[] other) -> Tensor",
1648 tensorlist_op);
1649 auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
1650 "_test::tensorlist_op", "");
1651 auto op = [&](torch::Tensor _1, at::TensorList _2) {
1652 return callOpUnboxed<torch::Tensor, const torch::Tensor&, at::TensorList>(
1653 opHandle, _1, _2);
1654 };
1655
1656 auto a = torch::tensor({1.}, {torch::kFloat32});
1657 auto b = torch::tensor({1.}, {torch::kFloat32});
1658 auto c = torch::tensor({1.}, {torch::kFloat32}).set_requires_grad(true);
1659 std::vector<torch::Tensor> vec = {b, c};
1660 auto out = op(a, vec);
1661
1662 ASSERT_THROWS_WITH(
1663 torch::autograd::grad({out}, {vec[0]}),
1664 "element 0 of the input tensors does not require grad");
1665 ASSERT_THROWS_WITH(
1666 torch::autograd::grad({out}, {vec[1]}), "is not implemented");
1667
1668 ASSERT_TRUE(at::allclose(op(a, vec), tensorlist_op(a, vec)));
1669 }
1670
1671 // TODO add these tests if needed
1672 // test_once_differentiable
1673 // test_sparse_backward
1674 // test_save_output_nr
1675 // test_free_deep_graph_pyfunction
1676 // test_naughty_anomaly_access
1677 // test_naughty_autograd-function_stashing_ctx
1678 // test_custom_autograd_repeated_grad_grad
1679 // test_return_leaf
1680 // test_anomaly_detect_nan
1681 // test_no_grad_copy
1682