xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/static_if.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import torch
3
class StaticIf(torch.nn.Module):
    """
    Example of an ``if`` whose predicate is known at trace time: the branch
    that is taken should be traced through, leaving only that path.
    """

    def forward(self, x):
        # The predicate depends only on the input's rank (number of
        # dimensions), which is the "static predicate value" this example
        # is about.
        if x.dim() != 3:
            # Any non-rank-3 input is passed back unchanged (same object).
            return x
        # Rank-3 input: add a (1, 1, 1) tensor of ones, which broadcasts
        # over every element.
        return x + torch.ones(1, 1, 1)
15
# Sample input: a single rank-3 tensor, so StaticIf.forward takes the branch
# that adds ones.
example_args = (torch.randn((3, 2, 2)),)
# Category tag(s) attached to this example.
tags = {"python.control-flow"}
# Module instance used by this example.
model = StaticIf()
19