# mypy: allow-untyped-defs
import torch


class MyAutogradFunction(torch.autograd.Function):
    """Custom autograd function: forward copies its input, backward adds 1.

    The gradient returned by ``backward`` differs from the true derivative of
    ``clone`` (which would pass ``grad_output`` through unchanged). This is
    presumably on purpose, so that it is observable whether this custom
    backward actually ran.
    """

    @staticmethod
    def forward(ctx, x):
        """Return a copy of ``x``; nothing is saved on ``ctx`` for backward."""
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output + 1`` as the gradient with respect to ``x``."""
        return grad_output + 1


class AutogradFunction(torch.nn.Module):
    """
    TorchDynamo does not keep track of backward() on autograd functions. We recommend
    using `allow_in_graph` to mitigate this problem.
    """

    def forward(self, x):
        """Apply ``MyAutogradFunction`` to ``x`` and return the result."""
        return MyAutogradFunction.apply(x)


# Example positional inputs for ``model``: a single random (3, 2) tensor.
example_args = (torch.randn(3, 2),)
model = AutogradFunction()