# mypy: allow-untyped-defs
import functools

import torch


def test_decorator(func):
    """Wrap *func* so that 1 is added to whatever it returns.

    ``functools.wraps`` copies *func*'s metadata (``__name__``, ``__doc__``,
    ...) onto the wrapper, so a decorated method keeps its original name.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs) + 1

    return wrapper


# Despite its ``test_`` prefix this is a decorator, not a test. Without this
# flag, pytest collects it whenever this module is imported into a test file,
# and then errors because no ``func`` fixture exists.
test_decorator.__test__ = False


class Decorator(torch.nn.Module):
    """
    Decorator calls are inlined into the exported function during tracing.
    """

    # The decorator wraps forward, so forward(x, y) computes (x + y) + 1.
    @test_decorator
    def forward(self, x, y):
        return x + y


# Sample inputs (two 3x2 random tensors) and an instance of the module.
example_args = (torch.randn(3, 2), torch.randn(3, 2))
model = Decorator()