# Owner(s): ["module: dynamo"]
# flake8: noqa: B950
import copy
import math
from dataclasses import dataclass

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._dynamo.utils
from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda


if HAS_CUDA:
    import triton

    from torch.testing._internal.triton_utils import add_kernel


class CustomFunc1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return foo + foo

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class CustomFunc3(torch.autograd.Function):
    # Tests handling of a graph break inside the forward function
    @staticmethod
    def forward(ctx, foo):
        result = foo + foo
        torch._dynamo.graph_break()
        result = result + foo
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * math.sqrt(result.numel())


class Module1(torch.nn.Module):
    def forward(self, foo):
        return CustomFunc1().apply(foo)


class Module2(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fn = CustomFunc1.apply

    def forward(self, foo):
        return self.fn(foo)


class Module3(torch.nn.Module):
    def forward(self, foo):
        return CustomFunc1().apply(foo)


class Module4(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fn = CustomFunc1.apply

    def forward(self, foo):
        return self.fn(foo)


class Module5(torch.nn.Module):
    def forward(self, foo):
        return CustomFunc3().apply(foo)


class Module6(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fn = CustomFunc3.apply

    def forward(self, foo):
        return self.fn(foo)


class LinearFunction(torch.autograd.Function):
    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias


class ModuleLinear(torch.nn.Module):
    def forward(self, input, weight, bias=None):
        return LinearFunction.apply(input, weight, bias)


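# With set_materialize_grads(False), undefined output grads are passed to
# backward as None instead of zero-filled tensors.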
class MaterializingGradFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.set_materialize_grads(False)
        return x.clone(), x.clone()

    @staticmethod
    def backward(ctx, grad_out1, grad_out2):
        return grad_out1, grad_out2


class MaterializingGradModule(torch.nn.Module):
    def forward(self, x):
        return MaterializingGradFunction.apply(x)


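# print() inside backward cannot be traced; test_print_in_bwd below expects
# torch._dynamo.exc.Unsupported ("builtin: print") when compiled with nopython=True.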
class CustomFuncBwdPrintGraphBreak(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return torch.add(foo, foo)

    @staticmethod
    def backward(ctx, grad_output):
        print("graph break!")
        return grad_output


class CustomFuncBwdPrintModule(torch.nn.Module):
    def forward(self, x):
        return CustomFuncBwdPrintGraphBreak.apply(x)


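# backward returns grad_output.stride() (a tuple) rather than a tensor;
# test_stride_in_bwd expects this to cause a graph break rather than a hard error.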
class CustomFuncStrideBwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        return torch.add(foo, foo)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.stride()


class CustomFuncStrideModule(torch.nn.Module):
    def forward(self, x):
        return CustomFuncStrideBwd.apply(x)


class CustomFuncSaveForBwd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, foo):
        result = foo + foo
        result = result + foo
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        return grad_output * math.sqrt(result.numel())


class SaveForBwdModule(torch.nn.Module):
    def forward(self, foo):
        return CustomFuncSaveForBwd().apply(foo)


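# The next two Functions differ only in the order of ctx.save_for_backward and
# ctx.mark_non_differentiable; both orderings should produce the same compiled result.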
class ContextSaveAndMark(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with torch.no_grad():
            ctx.save_for_backward(x)
            ctx.mark_non_differentiable(x)
            return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class ContextMarkAndSave(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        with torch.no_grad():
            ctx.mark_non_differentiable(x)
            ctx.save_for_backward(x)
            return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class ModuleWithGradFunc(torch.nn.Module):
    def __init__(self, func):
        super().__init__()
        self.f = func.apply

    def forward(self, x):
        return self.f(x)


class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
    # Sound behaviors, tested for working capture
    def test_autograd_function_equivalence(self):
        for grad in [True, False]:
            for i in range(1, 5):
                torch._dynamo.reset()
                model = globals()[f"Module{i}"]()
                opt_model = torch._dynamo.optimize("eager")(model)
                self.assertTrue(
                    torch.allclose(
                        opt_model(torch.ones(2, 3, requires_grad=grad)),
                        torch.tensor([2.0], requires_grad=grad),
                    )
                )

    def test_autograd_function_has_graph_break(self):
        for grad in [True, False]:
            x = torch.randn(10, requires_grad=grad)
            for model in [Module5(), Module6()]:
                torch._dynamo.reset()
                cnts = torch._dynamo.testing.CompileCounter()
                opt_model = torch._dynamo.optimize(cnts)(model)
                for _ in range(3):
                    ref = model(x)
                    res = opt_model(x)
                    self.assertTrue(torch.allclose(ref, res))
                self.assertEqual(cnts.frame_count, 2)

    def test_linear_setup_context(self):
        model = ModuleLinear()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        input = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        weight = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
        eager_result = model(input, weight)
        optim_result = opt_model(input, weight)
        self.assertEqual(optim_result, eager_result)

    def test_materialize_grad(self):
        model = MaterializingGradModule()
        opt_model = torch._dynamo.optimize("eager")(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        optim_result = opt_model(x)
        eager_result = model(x)
        self.assertEqual(optim_result, eager_result)

    def test_print_in_bwd(self):
        model = CustomFuncBwdPrintModule()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: print"):
            opt_model(x)

    def test_stride_in_bwd(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()
        model = CustomFuncStrideModule()
        opt_model = torch.compile(backend=cnt)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        ref = model(x)
        res = opt_model(x)

        self.assertEqual(ref, res)
        self.assertEqual(cnt.frame_count, 1)
        # graph break: Illegal getattr invocation stride in strict mode.
        self.assertEqual(
            list(torch._dynamo.utils.counters["graph_break"].values()), [1]
        )

    def test_enum_arg(self):
        from enum import Enum

        class SomeEnum(Enum):
            A = 0
            B = 1

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, e):
                if e is SomeEnum.A:
                    return x.sin()
                else:
                    return x.cos()

            @staticmethod
            def backward(ctx, g):
                return g

        @torch.compile(backend="eager", fullgraph=True)
        def f(x, enum):
            output = Foo.apply(
                x,
                enum,
            )
            return output

        x = torch.tensor([[1.0, 2, 3], [4, 5, 6]], requires_grad=True)
        y = f(x, SomeEnum.A)
        self.assertEqual(y, x.sin())

    def test_save_for_bwd(self):
        model = SaveForBwdModule()
        opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        opt_model(x)

    def test_allow_in_graph(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.allow_in_graph
        class AllowInGraphFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                torch._dynamo.graph_break()
                ctx.x0 = x.size(0)
                return x * 2

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out * ctx.x0

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            return AllowInGraphFunc.apply(x)

        x = torch.rand(2, 3, requires_grad=True)
        result = fn(x)

        self.assertEqual(result, AllowInGraphFunc.apply(x))
        self.assertEqual(cnt.frame_count, 1)

    def test_once_differentiable(self):
        from torch.autograd.function import once_differentiable

        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        class ScaleGradient(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            @once_differentiable
            def backward(ctx, grad):
                return grad * 0.5

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            return ScaleGradient.apply(x)

        x = torch.randn(3, requires_grad=True)
        result = fn(x)

        self.assertEqual(result, ScaleGradient.apply(x))
        self.assertEqual(cnt.frame_count, 1)

    def test_classmethod(self):
        class Shake(torch.autograd.Function):
            @classmethod
            def forward(cls, ctx, foo):
                return foo + foo

            @classmethod
            def backward(cls, ctx, grad_output):
                return grad_output

        def f(x):
            return Shake.apply(x)

        x = torch.randn(4, 4, 4, 4, requires_grad=True)
        opt_m = torch.compile(backend="eager")(f)
        opt_m(x)

    def test_function_context_save_and_mark(self):
        mod = ModuleWithGradFunc(ContextSaveAndMark)
        args, kwargs = ([torch.rand([1])], {})
        before = mod(*args, **kwargs)

        torch._dynamo.reset()
        compiled_model = torch._dynamo.optimize("eager")(mod)
        after = compiled_model(*args, **kwargs)
        self.assertEqual(before, after)

    def test_function_context_mark_and_save(self):
        mod = ModuleWithGradFunc(ContextMarkAndSave)
        args, kwargs = ([torch.rand([1])], {})
        before = mod(*args, **kwargs)

        torch._dynamo.reset()
        compiled_model = torch._dynamo.optimize("eager")(mod)
        after = compiled_model(*args, **kwargs)
        self.assertEqual(before, after)

    def test_multi_output(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone(), x.clone()

            @staticmethod
            def backward(ctx, grad1, grad2):
                return grad1 + grad2

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return Foo.apply(x)

        x = torch.randn(3, requires_grad=True)
        result = f(x)

        self.assertEqual(result, Foo.apply(x))
        self.assertEqual(cnt.frame_count, 1)

    def test_amp_custom_fwd_bwd(self):
        torch._dynamo.utils.counters.clear()
        cnt = torch._dynamo.testing.CompileCounter()

        class MyMM(torch.autograd.Function):
            @staticmethod
            @torch.amp.custom_fwd(device_type="cuda")
            def forward(ctx, a, b):
                ctx.save_for_backward(a, b)
                return a.mm(b)

            @staticmethod
            @torch.amp.custom_bwd(device_type="cuda")
            def backward(ctx, grad):
                a, b = ctx.saved_tensors
                return grad.mm(b.t()), a.t().mm(grad)

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(a, b):
            return MyMM.apply(a, b)

        a = torch.randn([64, 64], dtype=torch.float32, requires_grad=True)
        grad = a.clone()
        res = fn(a, a)
        res.backward(grad)

        self.assertEqual(res, MyMM.apply(a, a))
        self.assertEqual(cnt.frame_count, 1)

    def test_set_materialize_grads_no_graph_break(self):
        class MulY(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.set_materialize_grads(True)
                return x * 3

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out * 3

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            return MulY.apply(x)

        x = torch.tensor(2.0, requires_grad=True)
        result = f(x)
        result.sum().backward()
        self.assertEqual(result, MulY.apply(x))
        self.assertEqual(x.grad, 3.0)

    def test_user_defined_object_as_input(self):
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")

        @dataclass
        class Weird:
            x: int
            b: torch.Tensor
            c: torch.Tensor

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x: torch.Tensor, weird: Weird, z: torch.Tensor):
                ctx.save_for_backward(weird.b, weird.c)
                return weird.b * weird.c * x.clone()

            @staticmethod
            def backward(ctx, grad):
                b, c = ctx.saved_tensors
                return grad * b * c, None, grad * 2

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x, weird, z):
            return Foo.apply(x, weird, z)

        x = torch.tensor(2.0, requires_grad=True)
        weird = Weird(1.2, torch.tensor(2.5, requires_grad=True), torch.tensor(3.5))
        z = torch.tensor(3.0, requires_grad=True)

        result = f(x, weird, z)
        result.sum().backward()

        self.assertEqual(result, Foo.apply(x, weird, z))
        self.assertEqual(x.grad, 2.5 * 3.5)
        self.assertEqual(z.grad, 2.0)
        self.assertEqual(weird.b.grad, None)

        # Check that the graph captured by Dynamo is correct.
        actual_graph = torch._dynamo.testing.normalize_gm(
            cnt.graphs[0].print_readable(print_output=False)
        )
        self.assertExpectedInline(
            actual_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[]", L_z_: "f32[]", L_weird_b: "f32[]", L_weird_c: "f32[]"):
        l_x_ = L_x_
        l_z_ = L_z_
        l_weird_b = L_weird_b
        l_weird_c = L_weird_c

        function_ctx = torch.autograd.function.FunctionCtx();  function_ctx = None
        fwd_body_0 = self.fwd_body_0
        bwd_body_0 = self.bwd_body_0
        autograd_function_apply: "f32[]" = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_z_, l_weird_b, l_weird_c, args_tensor_mask = [True, False, True], non_differentiable_idx = []);  fwd_body_0 = bwd_body_0 = l_x_ = l_z_ = l_weird_b = l_weird_c = None
        return (autograd_function_apply,)

    class fwd_body_0(torch.nn.Module):
        def forward(self, ctx, x: "f32[]", z: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
            mul: "f32[]" = l_weird_b * l_weird_c
            clone: "f32[]" = x.clone();  x = None
            mul_1: "f32[]" = mul * clone;  mul = clone = None
            return (mul_1, [l_weird_b, l_weird_c])

    class bwd_body_0(torch.nn.Module):
        def forward(self, ctx, grad: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"):
            _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None

            mul: "f32[]" = grad * l_weird_b;  l_weird_b = None
            mul_1: "f32[]" = mul * l_weird_c;  mul = l_weird_c = None
            mul_2: "f32[]" = grad * 2;  grad = None

            _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
            return (mul_1, mul_2)
""",
        )

    def test_tensor_list_as_input(self):
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, tl):
                ctx.save_for_backward(tl[0], tl[1])
                return x.clone() * (tl[0] + tl[1])

            @staticmethod
            def backward(ctx, grad):
                tl0, tl1 = ctx.saved_tensors
                return grad * (tl0 + tl1), None

        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, tl):
            return Foo.apply(x, tl)

        x = torch.tensor(2.0, requires_grad=True)
        tl = [
            torch.tensor(3.0, requires_grad=True),
            torch.tensor(4.0, requires_grad=True),
        ]

        result = f(x, tl)
        result.sum().backward()

        self.assertEqual(result, Foo.apply(x, tl))
        self.assertEqual(x.grad, 7.0)
        self.assertEqual(tl[0].grad, None)
        self.assertEqual(tl[1].grad, None)

    def test_multiple_different_non_tensor_inputs(self):
        @dataclass
        class Weird:
            x: int
            b: torch.Tensor
            c: torch.Tensor

        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, weird, z, tl):
                ctx.save_for_backward(weird.b, weird.c, tl[0], tl[1])
                return x.clone() * weird.b * weird.c * tl[0]

            @staticmethod
            def backward(ctx, grad):
                b, c, tl0, _ = ctx.saved_tensors
                return grad * b * c * tl0, None, grad * 2, None

        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, weird, z, tl):
            return Foo.apply(x, weird, z, tl)

        x = torch.tensor(2.0, requires_grad=True)
        weird = Weird(
            1.2,
            torch.tensor(2.5, requires_grad=True),
            torch.tensor(3.5, requires_grad=True),
        )
        z = torch.tensor(3.0, requires_grad=True)
        tl = [
            torch.tensor(0.5, requires_grad=True),
            torch.tensor(0.6, requires_grad=True),
        ]

        result = f(x, weird, z, tl)
        result.sum().backward()

        self.assertEqual(result, Foo.apply(x, weird, z, tl))
        self.assertEqual(x.grad, 2.5 * 3.5 * 0.5)
        self.assertEqual(z.grad, 2.0)
        self.assertEqual(weird.b.grad, None)
        self.assertEqual(weird.c.grad, None)
        self.assertEqual(tl[0].grad, None)
        self.assertEqual(tl[1].grad, None)

    def test_backward_returns_none_for_tensor_input(self):
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(y)
                return x.clone() * y

            @staticmethod
            def backward(ctx, grad):
                (y,) = ctx.saved_tensors
                return grad * y, None

        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(x, y):
            return Foo.apply(x, y)

        x = torch.tensor(2.0, requires_grad=True)
        y = torch.tensor(3.0, requires_grad=True)

        result = f(x, y)
        result.sum().backward()

        self.assertEqual(result, Foo.apply(x, y))
        self.assertEqual(x.grad, 3.0)
        self.assertEqual(y.grad, None)

    def test_function_with_bound_free_variable(self):
        class LowerBound(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inputs, bound):
                ctx.save_for_backward(inputs, inputs.new_ones(1) * bound)
                return inputs.clamp(min=bound)

            @staticmethod
            def backward(ctx, grad_output):
                inputs, bound = ctx.saved_tensors
                return (inputs >= bound) * grad_output, None

        class MyMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.gamma = torch.nn.Parameter(torch.rand([4, 128, 32, 32]))

            def forward(self, x):
                gamma = LowerBound.apply(self.gamma, 1)
                return x + gamma

        mod = MyMod()
        args, kwargs = ([torch.rand([4, 128, 32, 32])], {})
        before = mod(*args, **kwargs)

        compiled_model = torch._dynamo.optimize("eager")(mod)
        after = compiled_model(*args, **kwargs)
        self.assertEqual(before, after)

    # I pulled all of these test cases from test_autograd.py
    # In the future, we should make the Dynamo test suite actually
    # run on test_autograd.py (it's disabled right now) and delete these.
    def test_smoke_from_test_autograd(self):
        def mult1(x):
            return x.prod(dim=-1).prod(dim=-1)

        class Mult(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = mult1(x)
                ctx.save_for_backward(x, y)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return (grad_output * y)[:, None, None] / x

        mult2 = Mult.apply

        class Double(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x**2
                ctx.save_for_backward(x, y)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                x, _ = ctx.saved_tensors
                return grad_output * 2 * x

        # this is equivalent, but uses the output of .forward() in .backward()
        class Double2(Double):
            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return grad_output * 2 * y / x

        double = Double.apply
        double2 = Double2.apply

        class Identity(torch.autograd.Function):
            @staticmethod
            def forward(ctx, a, b):
                return a, a + b

            @staticmethod
            def backward(ctx, grad_a, grad_b):
                return grad_a + grad_b, grad_b

        class MyFunc2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.clone()

            @staticmethod
            def backward(ctx, gO):
                return torch.tensor(float("nan")).expand(10, 10)

        def run_fn(a):
            out = MyFunc2.apply(a)
            return out.sum()

        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.view_as(inp)

            @staticmethod
            def backward(ctx, grad):
                return grad

        class MyAdder(torch.autograd.Function):
            @staticmethod
            def forward(ctx, a, b):
                a.add_(b)
                ctx.mark_dirty(a)
                return a

            @staticmethod
            def backward(ctx, grad):
                return grad, grad

        class InplaceMul(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                result = x.mul_(2)
                ctx.mark_dirty(result)
                return result

            @staticmethod
            def backward(ctx, grad_output):
                pass

            @staticmethod
            def jvp(ctx, x_t):
                if jvp_err:  # noqa: F821
                    return x_t
                else:
                    return x_t.mul_(2)

        class MyFn2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                return x + y, x

            @staticmethod
            def vjp(ctx, gO1, gO2):
                return gO1 + gO2, gO1

            @staticmethod
            def jvp(ctx, x_t, y_t):
                return x_t + y_t, fn(x_t)  # noqa: F821

        class MyFn3(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp, inplace):
                view = inp.clone()[:3]
                if inplace:
                    view += 2
                return view

            @staticmethod
            def backward(ctx, grad):
                return grad, None

        def test():
            x = torch.ones(2, 4, 4).requires_grad_()
            mult2(x)

            x = torch.tensor(2).double().requires_grad_()
            double(x)
            double2(x)

            x = torch.randn(5, 5, requires_grad=True)
            y = torch.randn(5, 5, requires_grad=True)
            q, p = Identity.apply(x, y)

            a = torch.rand(1, 2)
            b = torch.rand(1, requires_grad=True)
            view_a = MyFn.apply(a)

            a = torch.ones(2, requires_grad=True)
            b = torch.ones(2, requires_grad=True)
            c = MyAdder.apply(a.clone(), b)
            c.sum().backward()

            z = torch.tensor(1.0, requires_grad=True)
            x = z.clone()
            y = InplaceMul.apply(x)

            a = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
            b = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
            c = torch.tensor(1.0, dtype=torch.double)
            d = torch.tensor(1.0, dtype=torch.double)
            MyFn2.apply(a, b)
            MyFn2.apply(c, d)

            base = torch.rand(10, requires_grad=True)
            foo = MyFn3.apply(base, False)

        test()
        opt_test = torch._dynamo.optimize("eager")(test)
        opt_test()

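    # FooTensor below is a __torch_dispatch__ wrapper subclass created inside the
    # compiled region; note the torch._dynamo.allow_in_graph(FooTensor) call that
    # the test relies on.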
    def test_tensor_subclass_intermediary_input(self):
        class FooTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, data, config, scale):
                self = torch.Tensor._make_wrapper_subclass(
                    cls,
                    config[0],
                    strides=config[1],
                    storage_offset=config[2],
                    dtype=config[3],
                    layout=config[4],
                    requires_grad=config[5],
                    device=data.device,
                )
                self._data = data
                self._config = config
                self._scale = scale
                return self

            def __repr__(self):
                return "FooTensor"

            def __tensor_flatten__(self):
                return ("_data",), (
                    self._config,
                    self._scale,
                )

            @staticmethod
            def __tensor_unflatten__(tensors, metadatas, outer_size, outer_stride):
                return FooTensor(tensors["_data"], metadatas[0], metadatas[1])

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs=None):
                # Handling clone and view here is only so that Dynamo fakeification
                # passes; it is not intended to handle user code.
                if func == torch.ops.aten.clone.default:
                    return FooTensor(
                        args[0]._data.clone(), args[0]._config, args[0]._scale
                    )
                elif func == torch.ops.aten.view.default:
                    new_data = args[0]._data.view(*args[1:])
                    return FooTensor(new_data, args[0]._config, args[0]._scale)

                raise NotImplementedError

        class foo_autograd_fn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                # access some data from `x`, where `x` is a tensor subclass
                x2 = x._data + 1.0
                # create and return a tensor subclass from within a torch.autograd.Function
                x3 = FooTensor(x2, x._config, x._scale)
                return x3._data

            @staticmethod
            def backward(ctx, g):
                return g

        x_ref = torch.randn(4, 4).requires_grad_(True)
        x = copy.deepcopy(x_ref)
        scale = torch.tensor(1.0)
        # Weird that this is needed, but not having this breaks a lot of things
        torch._dynamo.allow_in_graph(FooTensor)

        def foo(x, scale):
            config = (
                x.size(),
                x.stride(),
                x.storage_offset(),
                x.dtype,
                x.layout,
                x.requires_grad,
            )
            x = FooTensor(x, config, scale)
            x = foo_autograd_fn.apply(x)
            return x

        y_ref = foo(x_ref, scale)
        y_ref.sum().backward()

        foo_opt = torch.compile(foo, backend="eager")
        y = foo_opt(x, scale)
        y.sum().backward()

        self.assertEqual(y, y_ref)
        self.assertEqual(x.grad, x_ref.grad)

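    # Regression test for issue 111031: a SymInt stashed on ctx (ctx.x0 = x.size(0))
    # must be usable from backward when compiling with dynamic=True.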
    def test_smuggle_symint_issue_111031(self):
        from torch.autograd import Function

        class Foo(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.x0 = x.size(0)
                return x * 2

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out * ctx.x0

        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True, dynamic=True)
        def foo(x):
            return Foo.apply(x)

        foo(torch.randn(2, requires_grad=True))
        self.assertEqual(cnts.frame_count, 1)

    def test_needs_input_grad(self):
        cnt = torch._dynamo.testing.CompileCounter()

        class NeedsInputGradFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, foo):
                result = foo + foo
                ctx.save_for_backward(result)
                return result

            @staticmethod
            @torch.compile(backend=cnt, fullgraph=True)
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                if ctx.needs_input_grad[0]:
                    return grad_output * result.sin()
                return None

        x = torch.randn(10, requires_grad=True)
        NeedsInputGradFunc.apply(x).sum().backward()
        self.assertEqual(x.grad.shape, x.shape)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

    def test_repeated_save_for_backward_calls(self):
        from torch.autograd import Function

        class Foo(Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x)
                ctx.save_for_backward(x, y)
                return x * y

            @staticmethod
            def backward(ctx, grad_out):
                x, y = ctx.saved_tensors
                return grad_out * x, grad_out * y

        cnts = torch._dynamo.testing.CompileCounter()

        def foo(x, y):
            return Foo.apply(x, y)

        x_ref = torch.randn(2, requires_grad=True)
        y_ref = torch.randn(2, requires_grad=True)
        x_test = x_ref.clone().detach().requires_grad_()
        y_test = y_ref.clone().detach().requires_grad_()

        out_ref = foo(x_ref, y_ref)
        out_ref.sum().backward()

        out_test = torch.compile(foo, backend=cnts)(x_test, y_test)
        out_test.sum().backward()

        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(out_ref, out_test)
        self.assertEqual(x_ref.grad, x_test.grad)
        self.assertEqual(y_ref.grad, y_test.grad)

    def test_smuggle_tensor_and_complex_structures(self):
        from torch.autograd import Function

        class Foo(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.x0 = x
                ctx.x1 = [1, 2, 3]
                return x * 2

            @staticmethod
            def backward(ctx, grad_out):
                x0mul = grad_out * ctx.x0
                for i in ctx.x1:
                    x0mul = (x0mul * i) + x0mul
                return x0mul

        cnts = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True, dynamic=True)
        def foo(x):
            return Foo.apply(x)

        foo(torch.randn(2, requires_grad=True))
        self.assertEqual(cnts.frame_count, 1)

    def test_mark_non_differentiable(self):
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        from torch.autograd import Function

        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x, y):
                out1 = x.sin()
                out2 = y * 2
                ctx.mark_non_differentiable(out2)
                return out1, out2

            @staticmethod
            def backward(ctx, grad1, grad2):
                return grad1.cos(), grad2 * 0.0

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x, y):
            return MyFunction.apply(x, y)

        x = torch.tensor(10.0, requires_grad=True)
        y = torch.tensor(20.0, requires_grad=True)
        ref1, ref2 = MyFunction.apply(x, y)
        res1, res2 = fn(x, y)
        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
        # Ensure out1 requires gradients, out2 does not.
        self.assertTrue(ref1.requires_grad)
        self.assertTrue(res1.requires_grad)
        self.assertFalse(ref2.requires_grad)
        self.assertFalse(res2.requires_grad)
        res1.sum().backward()

        # Check that the graph captured by Dynamo is correct.
        actual_graph = torch._dynamo.testing.normalize_gm(
            cnt.graphs[0].print_readable(print_output=False)
        )
        self.assertExpectedInline(
            actual_graph,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[]", L_y_: "f32[]"):
        l_x_ = L_x_
        l_y_ = L_y_

        function_ctx = torch.autograd.function.FunctionCtx();  function_ctx = None
        fwd_body_0 = self.fwd_body_0
        bwd_body_0 = self.bwd_body_0
        autograd_function_apply = torch.ops.higher_order.autograd_function_apply(fwd_body_0, bwd_body_0, l_x_, l_y_, args_tensor_mask = [True, True], non_differentiable_idx = [1]);  fwd_body_0 = bwd_body_0 = l_x_ = l_y_ = None
        getitem: "f32[]" = autograd_function_apply[0]
        getitem_1: "f32[]" = autograd_function_apply[1];  autograd_function_apply = None
        return (getitem, getitem_1)

    class fwd_body_0(torch.nn.Module):
        def forward(self, ctx, x: "f32[]", y: "f32[]"):
            out1: "f32[]" = x.sin();  x = None

            out2: "f32[]" = y * 2;  y = None
            return ((out1, out2), [])

    class bwd_body_0(torch.nn.Module):
        def forward(self, ctx, grad1: "f32[]", grad2: "f32[]"):
            _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None

            cos: "f32[]" = grad1.cos();  grad1 = None
            mul: "f32[]" = grad2 * 0.0;  grad2 = None

            _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
            return (cos, mul)
""",
        )

    def test_mark_multi_output_non_differentiable(self):
        from torch.autograd import Function

        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x, y, z):
                out1 = x.sin()
                out2 = y * 2
                out3 = z + 3
                ctx.mark_non_differentiable(out2, out3)
                return out1, out2, out3

            @staticmethod
            def backward(ctx, grad1, grad2, grad3):
                return grad1.cos(), grad2, grad3

        @torch.compile(backend="aot_eager", fullgraph=True)
        def fn(x, y, z):
            return MyFunction.apply(x, y, z)

        x = torch.tensor(10.0, requires_grad=True)
        y = torch.tensor(20.0, requires_grad=True)
        z = torch.tensor(30.0, requires_grad=True)
        ref1, ref2, ref3 = MyFunction.apply(x, y, z)
        res1, res2, res3 = fn(x, y, z)
        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
        self.assertEqual(ref3, res3)
        # Ensure out1 requires gradients; out2 and out3 do not.
        self.assertTrue(ref1.requires_grad)
        self.assertTrue(res1.requires_grad)
        self.assertFalse(ref2.requires_grad)
        self.assertFalse(res2.requires_grad)
        self.assertFalse(ref3.requires_grad)
        self.assertFalse(res3.requires_grad)
        res1.sum().backward()

    def test_default_values(self):
        from torch.autograd import Function

        class Foo(Function):
            @staticmethod
            def forward(ctx, x, alpha=0.99):
                return x

            @staticmethod
            def backward(ctx, grad_out):
                return grad_out

        @torch.compile
        def foo(x):
            return Foo.apply(x)

        # Make sure guards for default values do not crash
        foo(torch.randn(2))
        foo(torch.randn(2, requires_grad=True))

    def test_tuple_arg(self):
        cnt = torch._dynamo.testing.CompileCounter()

        class TupleArgFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, shape):
                ctx.save_for_backward(torch.randn(shape))
                return x + 1

            @staticmethod
            def backward(ctx, grad_output):
                (result,) = ctx.saved_tensors
                return result, None

        @torch.compile(backend=cnt, fullgraph=True)
        def fn():
            return TupleArgFunc.apply(x, shape)

        shape = (10, 10)
        x = torch.randn(shape, requires_grad=True)
        out = fn()
        out.sum().backward()
        self.assertEqual(out, x + 1)
        self.assertEqual(x.grad.shape, shape)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 2)

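    # The remaining tests require CUDA and launch the user-defined Triton add_kernel
    # from inside autograd.Function.forward.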
    @requires_cuda
    def test_triton_kernel_basic(self):
        class Add(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x, y)
                output = torch.zeros_like(x)
                n_elements = output.numel()
                grid = lambda meta: (  # noqa: E731
                    triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
                )
                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
                return output

            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return x * grad_output, y * grad_output

        @torch.compile(fullgraph=True, backend="inductor")
        def f(x, y):
            z = Add.apply(x, y)
            return z

        x = torch.randn(10, device="cuda", requires_grad=True)
        y = torch.randn(10, device="cuda", requires_grad=True)
        z = f(x, y)
        loss = z.sum()
        loss.backward()
        self.assertEqual(x + y, z)

    @requires_cuda
    def test_triton_kernel_multiple_out(self):
        class Add(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x, y)
                ctx.t1 = x
                ctx.t2 = y
                output = torch.zeros_like(x)
                n_elements = output.numel()
                grid = lambda meta: (  # noqa: E731
                    triton.cdiv(n_elements, meta["BLOCK_SIZE"]),
                )
                add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
                return output, x

            @staticmethod
            def backward(ctx, grad_output, old_x):
                x, y = ctx.saved_tensors
                x1 = ctx.t1
                y1 = ctx.t2
                return old_x * x * x1 * grad_output, y * y1 * grad_output

        @torch.compile(fullgraph=True, backend="inductor")
        def f(x, y):
            z = Add.apply(x, y)
            return z

        x = torch.randn(10, device="cuda", requires_grad=True)
        y = torch.randn(10, device="cuda", requires_grad=True)
        z, _ = f(x, y)
        loss = z.sum()
        loss.backward()
        self.assertEqual(x + y, z)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()