xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/specialized_attribute.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from enum import Enum
3
4import torch
5
# Closed set of animals; each member's value is a string (COW -> "moo").
# Built with the Enum functional API: the same members and values as a class body.
Animal = Enum("Animal", {"COW": "moo"})

class SpecializedAttribute(torch.nn.Module):
    """Example module whose plain Python attributes are specialized.

    ``a`` (a str) and ``b`` (an int) are ordinary attributes, not parameters
    or buffers. The example illustrates that they are specialized, i.e.
    treated as fixed values, when the module is traced/exported.
    """

    def __init__(self) -> None:
        super().__init__()
        # Plain Python values; nothing is registered as a parameter or buffer.
        self.a, self.b = "moo", 4

    def forward(self, x):
        """Return ``x * x + self.b``.

        Raises:
            ValueError: if ``self.a`` differs from ``Animal.COW.value``.
        """
        # Guard clause: only the configuration matching Animal.COW is supported.
        if self.a != Animal.COW.value:
            raise ValueError("bad")
        squared = x * x
        return squared + self.b

# Positional example inputs for ``model``: a single random tensor of shape (3, 2).
# NOTE(review): presumably consumed by the export-db example harness; confirm.
example_args = (torch.randn((3, 2)),)
# The module instance this example exercises.
model = SpecializedAttribute()
