xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/decorator.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import functools
3
4import torch
5
6def test_decorator(func):
7    @functools.wraps(func)
8    def wrapper(*args, **kwargs):
9        return func(*args, **kwargs) + 1
10
11    return wrapper
12
class Decorator(torch.nn.Module):
    """
    Decorator calls are inlined into the exported function during tracing.
    """

    # ``test_decorator`` wraps ``forward`` so the returned value is
    # ``(x + y) + 1``; the exported program is expected to contain that
    # extra ``+ 1`` rather than a call to the decorator.
    @test_decorator
    def forward(self, x, y):
        return x + y
21
# Sample inputs for ``Decorator.forward(x, y)``: two independent random
# tensors of shape (3, 2), drawn in order (x first, then y).
example_args = tuple(torch.randn(3, 2) for _ in range(2))

# Module instance exposed by this example.
model = Decorator()
24