# mypy: allow-untyped-defs
from typing import List

import torch


class ListUnpack(torch.nn.Module):
    """
    Lists are treated as static construct, therefore unpacking should be
    erased after tracing.
    """

    def forward(self, args: List[torch.Tensor]):
        """
        Add the first tensor in ``args`` to the second one.

        The list is split with starred unpacking into its first element and
        the list of remaining elements; only the first of the remaining
        elements is used.

        Args:
            args: A list holding at least two tensors. Elements beyond the
                second are unpacked but ignored.

        Returns:
            The result of ``args[0] + args[1]``.

        Raises:
            ValueError: If ``args`` is empty (nothing to unpack into the
                first target).
            IndexError: If ``args`` has exactly one element (the remainder
                list is empty).
        """
        # Starred unpacking is intentional: it is the construct this
        # example exercises.
        first, *rest = args
        return first + rest[0]


# Positional arguments for the model: a 1-tuple whose single item is the list
# passed as ``args``.
example_args = ([torch.randn(3, 2), torch.tensor(4), torch.tensor(5)],)
tags = {"python.control-flow", "python.data-structure"}
model = ListUnpack()