xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/class_method.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
class ClassMethod(torch.nn.Module):
    """
    Demonstrates that class methods are inlined during tracing.

    ``forward`` calls the same classmethod through three different lookup
    paths (the instance, ``self.__class__`` and ``type(self)``); each one
    resolves to ``ClassMethod.method``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Maps 4 input features to 2 output features.
        self.linear = torch.nn.Linear(4, 2)

    @classmethod
    def method(cls, x):
        """Return ``x + 1``; ``cls`` is not used."""
        return x + 1

    def forward(self, x):
        """Apply ``self.linear``, then multiply three ``method`` results."""
        projected = self.linear(x)
        # Each factor reaches ``method`` via a different attribute path.
        # The operations run in the same order as ``a * b * c``
        # left-to-right: a, b, a*b, c, (a*b)*c.
        from_instance = self.method(projected)
        from_dunder_class = self.__class__.method(projected)
        running = from_instance * from_dunder_class
        return running * type(self).method(projected)
20
# A single random (3, 4) tensor; its last dimension matches the 4 input
# features of ``ClassMethod.linear``. It is drawn before the model is built,
# so the RNG is consumed in the same order as before.
example_args = (torch.randn((3, 4)),)
model = ClassMethod()
23