from typing import Any

import torch


@torch.jit.script
class MyScriptClass:
    """Intended to be scripted.

    Minimal TorchScript class holding a single value in ``foo``.
    """

    # Unannotated parameters default to ``torch.Tensor`` in TorchScript; the
    # explicit annotations below spell that default out.
    def __init__(self, x: torch.Tensor) -> None:
        self.foo = x

    def set_foo(self, x: torch.Tensor) -> None:
        """Overwrite the stored ``foo`` value with ``x``."""
        self.foo = x


@torch.jit.script
def uses_script_class(x: torch.Tensor):
    """Intended to be scripted.

    Wrap ``x`` in a ``MyScriptClass`` and hand back its ``foo`` attribute,
    i.e. ``x`` itself, routed through the scripted class.
    """
    holder = MyScriptClass(x)
    return holder.foo


class IdListFeature:
    """Plain Python class carrying a single ``id_list`` tensor.

    NOTE(review): not decorated with ``torch.jit.script``; presumably it is
    compiled on demand when ``UsesIdListFeature`` (which references it in an
    ``isinstance`` check) is scripted -- confirm against the tests using it.
    """

    def __init__(self) -> None:
        # A 1x1 tensor of ones.
        self.id_list = torch.ones(1, 1)

    def returns_self(self) -> "IdListFeature":
        """Construct and return a brand-new ``IdListFeature``.

        NOTE(review): despite its name this does not return ``self``; it
        builds a fresh instance. The self-referencing return annotation may
        be what this fixture exists to exercise -- confirm before changing.
        """
        fresh = IdListFeature()
        return fresh


class UsesIdListFeature(torch.nn.Module):
    """Module whose ``forward`` unwraps ``IdListFeature`` inputs."""

    def forward(self, feature: Any):
        """Return ``feature.id_list`` for an ``IdListFeature``, else ``feature``."""
        if isinstance(feature, IdListFeature):
            return feature.id_list
        # Any other input is passed through untouched.
        return feature