xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/list_contains.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
class ListContains(torch.nn.Module):
    """
    Exercise ``in`` / ``not in`` checks against Python list literals, where
    the left-hand side is either a tensor dimension (possibly a dynamic
    shape) or a plain constant.
    """

    def forward(self, x):
        # The trailing dimension has to be one of the listed sizes.
        trailing = x.size(-1)
        assert trailing in [6, 2]

        # The leading dimension has to avoid every listed size.
        leading = x.size(0)
        assert leading not in [4, 5, 6]

        # Containment check on constants only; it always holds.
        assert "monkey" not in ["cow", "pig"]

        doubled = x + x
        return doubled
14
# Module-level fields for this example: the model instance, a sample input,
# and descriptive tags.
model = ListContains()

# Shape (3, 2) satisfies forward's checks: trailing 2 is in [6, 2] and
# leading 3 is not in [4, 5, 6].
example_args = (torch.randn(3, 2),)

tags = {
    "torch.dynamic-shape",
    "python.data-structure",
    "python.assert",
}
18