# mypy: allow-untyped-defs
from enum import Enum

import torch


class Animal(Enum):
    """Enum whose member value ``forward`` compares ``self.a`` against."""

    COW = "moo"


class SpecializedAttribute(torch.nn.Module):
    """
    Model attributes are specialized.

    ``a`` (a ``str``) and ``b`` (an ``int``) are ordinary Python attributes,
    not tensors or parameters, so the branch on ``a`` and the value of ``b``
    are expected to be treated as fixed constants rather than graph inputs.
    """

    def __init__(self) -> None:
        super().__init__()
        # Plain (non-tensor) attributes whose concrete values drive forward().
        self.a = "moo"
        self.b = 4

    def forward(self, x):
        """Return ``x * x + self.b``.

        Raises:
            ValueError: if ``self.a`` differs from ``Animal.COW.value``.
        """
        # Guard clause: any attribute value other than the enum's value is rejected.
        if self.a != Animal.COW.value:
            raise ValueError("bad")
        return x * x + self.b


# Sample positional inputs for ``model.forward``: one random 3x2 tensor.
example_args = (torch.randn(3, 2),)
model = SpecializedAttribute()