xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/list_unpack.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import List
3
4import torch
5
class ListUnpack(torch.nn.Module):
    """
    Demonstrates star-unpacking of a list argument.

    Lists are treated as a static construct during tracing. The unpacking
    itself should therefore be erased after tracing, leaving only the tensor
    arithmetic.
    """

    def forward(self, args: List[torch.Tensor]):
        """
        Return the sum of the first two tensors in ``args``.

        ``args`` is split with star-unpacking into its first element and the
        remaining elements. Only the first of the remaining elements is used;
        anything after it is ignored.

        Raises ValueError if ``args`` is empty (nothing to unpack), and
        IndexError if it holds only one element.
        """
        first, *remaining = args
        second = remaining[0]
        return first + second
19
# Example inputs: a single positional argument, a list of three tensors.
# Only the first two elements are combined by ``ListUnpack.forward``.
example_args = (
    [
        torch.randn(3, 2),
        torch.tensor(4),
        torch.tensor(5),
    ],
)

# Category tags describing the Python features this example exercises.
tags = {
    "python.control-flow",
    "python.data-structure",
}

# Module instance for this example; presumably paired with ``example_args``
# by the harness that collects these example files.
model = ListUnpack()
23