xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/autograd_function.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
class MyAutogradFunction(torch.autograd.Function):
    """Toy custom autograd op: identity forward, gradient shifted by one.

    The forward pass copies its input; the backward pass deliberately
    returns ``grad_output + 1`` so the custom gradient is observable.
    """

    @staticmethod
    def forward(ctx, x):
        # Identity op — return a fresh tensor rather than the input itself.
        out = torch.clone(x)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        # Custom gradient: incoming gradient shifted by a constant of one.
        return 1 + grad_output
12
class AutogradFunction(torch.nn.Module):
    """
    TorchDynamo does not keep track of backward() on autograd functions. We recommend to
    use `allow_in_graph` to mitigate this problem.
    """

    def forward(self, x):
        # Delegate to the custom autograd function's apply().
        result = MyAutogradFunction.apply(x)
        return result
21
# Module-level export fixtures: a sample input tuple and the model instance.
model = AutogradFunction()
example_args = (torch.randn(3, 2),)
24